Skip to content

Commit

Permalink
evaluation in UT
Browse files Browse the repository at this point in the history
  • Loading branch information
dpomianbuff committed Sep 18, 2023
1 parent 460f272 commit 985fd7b
Show file tree
Hide file tree
Showing 32 changed files with 4,726 additions and 286 deletions.
42 changes: 0 additions & 42 deletions evaluation/.gitignore

This file was deleted.

42 changes: 0 additions & 42 deletions evaluation/build.gradle.kts

This file was deleted.

16 changes: 16 additions & 0 deletions src/main/kotlin/com/intellij/ml/llm/template/evaluation/Oracle.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package com.intellij.ml.llm.template.evaluation

/**
 * Ground-truth record used to evaluate extract-function suggestions.
 *
 * Pairs a known-good extraction range with the function it was taken from.
 * NOTE(review): `loc` presumably means lines of code of the extracted
 * fragment — confirm against the oracle data producer.
 */
data class OracleData (
    // Metadata of the function hosting the extraction candidate.
    val hostFunctionData: HostFunctionData,
    // Size of the extracted fragment (assumed: lines of code — TODO confirm).
    val loc: Int,
    // First line of the oracle extraction range (inclusive).
    val lineStart: Int,
    // Last line of the oracle extraction range (inclusive).
    val lineEnd: Int,
    // Source file the range refers to.
    val filename: String,
)

/**
 * Location and size of the function that hosts an extraction candidate.
 */
data class HostFunctionData (
    // First line of the host function (inclusive).
    val lineStart: Int,
    // Last line of the host function (inclusive).
    val lineEnd: Int,
    // Lines of code in the function body.
    val bodyLoc: Int,
    // Link to the function on GitHub; empty when not available.
    val githubUrl: String = ""
)
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ data class EFCandidate(
@SerializedName("heat")
var heat: Int = -1

@SerializedName("overlap")
var overlap: Int = 0

@SerializedName("heuristic")
var heuristic = ""

@SerializedName("length")
val length = lineEnd - lineStart + 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package com.intellij.ml.llm.template.extractfunction

/**
 * Feature toggles controlling which heuristics / learning modes the
 * extract-function pipeline applies (registered via EFSettings).
 *
 * NOTE(review): the constructor parameter `s` is never stored or read, and the
 * string values do not always match the constant names (e.g. "DIALOG") —
 * confirm whether the parameter can be removed or should become `val s`.
 */
enum class EFSettingType(s: String) {
    IF_BLOCK_HEURISTIC("DIALOG"),
    // Defect fixed: the flattened diff left this entry duplicated (once
    // without and once with a trailing comma), which cannot compile; a
    // single declaration is kept.
    PREV_ASSIGNMENT_HEURISTIC("PARENT_CLASS"),
    VERY_LARGE_BLOCK_HEURISTIC("VERY_LARGE_BLOCK"),
    MULTISHOT_LEARNING("MULTISHOT_LEARNING")
}

class EFSettings {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@ package com.intellij.ml.llm.template.intentions
import com.intellij.codeInsight.intention.IntentionAction
import com.intellij.codeInsight.unwrap.ScopeHighlighter
import com.intellij.ml.llm.template.LLMBundle
import com.intellij.ml.llm.template.evaluation.HostFunctionData
import com.intellij.ml.llm.template.extractfunction.EFCandidate
import com.intellij.ml.llm.template.extractfunction.EFSettingType
import com.intellij.ml.llm.template.extractfunction.EFSettings
import com.intellij.ml.llm.template.extractfunction.EFSuggestion
import com.intellij.ml.llm.template.models.GPTExtractFunctionRequestProvider
import com.intellij.ml.llm.template.models.LLMBaseResponse
import com.intellij.ml.llm.template.models.LLMRequestProvider
import com.intellij.ml.llm.template.models.sendChatRequest
import com.intellij.ml.llm.template.prompts.fewShotExtractSuggestion
import com.intellij.ml.llm.template.models.LlmMultishotResponseData
import com.intellij.ml.llm.template.models.MultishotSender
import com.intellij.ml.llm.template.prompts.multishotExtractFunctionPrompt
import com.intellij.ml.llm.template.telemetry.*
import com.intellij.ml.llm.template.ui.ExtractFunctionPanel
import com.intellij.ml.llm.template.utils.*
Expand Down Expand Up @@ -41,10 +45,16 @@ abstract class ApplyExtractFunctionTransformationIntention(
private val codeTransformer = CodeTransformer()
private val telemetryDataManager = EFTelemetryDataManager()
private var llmResponseTime = 0L
private var hostFunctionData = HostFunctionData(-1, -1, -1)

init {
codeTransformer.addObserver(EFLoggerObserver(logger))
codeTransformer.addObserver(TelemetryDataObserver())
EFSettings.instance
.add(EFSettingType.IF_BLOCK_HEURISTIC)
.add(EFSettingType.MULTISHOT_LEARNING)
.add(EFSettingType.PREV_ASSIGNMENT_HEURISTIC)
.add(EFSettingType.VERY_LARGE_BLOCK_HEURISTIC)
}

override fun getFamilyName(): String = LLMBundle.message("intentions.apply.transformation.family.name")
Expand All @@ -66,6 +76,8 @@ abstract class ApplyExtractFunctionTransformationIntention(
val startLineNumber = editor.document.getLineNumber(selectionModel.selectionStart) + 1
val withLineNumbers = addLineNumbersToCodeSnippet(codeSnippet, startLineNumber)

hostFunctionData = HostFunctionData(startLineNumber, startLineNumber + codeSnippet.lines().size - 1, 0)

telemetryDataManager.addHostFunctionTelemetryData(
EFTelemetryDataUtils.buildHostFunctionTelemetryData(
codeSnippet = codeSnippet,
Expand All @@ -81,30 +93,46 @@ abstract class ApplyExtractFunctionTransformationIntention(

private fun invokeLlm(text: String, project: Project, editor: Editor, file: PsiFile) {
logger.info("Invoking LLM with text: $text")
val messageList = fewShotExtractSuggestion(text)
val messageList = multishotExtractFunctionPrompt(text)

val task = object : Task.Backgroundable(
project, LLMBundle.message("intentions.request.extract.function.background.process.title")
) {
override fun run(indicator: ProgressIndicator) {
val now = System.nanoTime()
val response = sendChatRequest(
project, messageList, efLLMRequestProvider.chatModel, efLLMRequestProvider
)
if (response != null) {
val responseList = MultishotSender(efLLMRequestProvider, project).sendRequest(text, emptyList(), 5, 1.0)
if (responseList.isNotEmpty()) {
invokeLater {
llmResponseTime = System.nanoTime() - now
if (response.getSuggestions().isEmpty()) {
showEFNotification(
project,
LLMBundle.message("notification.extract.function.with.llm.no.suggestions.message"),
NotificationType.INFORMATION
)
} else {
processLLMResponse(response, project, editor, file)
}
llmResponseTime = responseList.sumOf { it.processingTime }
processLLMResponse(responseList, project, editor, file)
}
}
else {
showEFNotification(
project,
LLMBundle.message("notification.extract.function.with.llm.no.suggestions.message"),
NotificationType.INFORMATION
)
}
// val response = sendChatRequest(
// project = project,
// messages = messageList,
// model = efLLMRequestProvider.chatModel
// )
// if (response != null) {
// invokeLater {
// llmResponseTime = System.nanoTime() - now
// if (response.getSuggestions().isEmpty()) {
// showEFNotification(
// project,
// LLMBundle.message("notification.extract.function.with.llm.no.suggestions.message"),
// NotificationType.INFORMATION
// )
// } else {
// processLLMResponse(response, project, editor, file)
// }
// }
// }
}
}
ProgressManager.getInstance().runProcessWithProgressAsynchronously(task, BackgroundableProcessIndicator(task))
Expand All @@ -125,12 +153,17 @@ abstract class ApplyExtractFunctionTransformationIntention(
return filteredCandidates
}

private fun processLLMResponse(response: LLMBaseResponse, project: Project, editor: Editor, file: PsiFile) {
private fun processLLMResponse(llmResponseData: List<LlmMultishotResponseData>, project: Project, editor: Editor, file: PsiFile) {
val now = System.nanoTime()
val efSuggestionList = mutableListOf<EFSuggestion>()

llmResponseData.filter { it.llmResponse != null }
llmResponseData.forEach {
efSuggestionList.addAll(identifyExtractFunctionSuggestions(it.llmResponse!!.getSuggestions()[0].text).suggestionList)
}

val llmResponse = response.getSuggestions()[0]
val efSuggestionList = identifyExtractFunctionSuggestions(llmResponse.text)
val builtCandidates = EFCandidateFactory().buildCandidates(efSuggestionList.suggestionList, editor, file).toList()
var builtCandidates = EFCandidateFactory().buildDistinctCandidates(efSuggestionList, editor, file).toList()
builtCandidates = EFCandidateUtils.rankByHeat(builtCandidates, hostFunctionData)
val candidates = builtCandidates.distinct()
if (candidates.isEmpty()) {
showEFNotification(
Expand All @@ -147,7 +180,7 @@ abstract class ApplyExtractFunctionTransformationIntention(

telemetryDataManager.addCandidatesTelemetryData(
buildCandidatesTelemetryData(
efSuggestionList.suggestionList.size,
efSuggestionList.size,
candidatesApplicationTelemetryObserver.getData()
)
)
Expand All @@ -161,7 +194,8 @@ abstract class ApplyExtractFunctionTransformationIntention(
)
sendTelemetryData()
} else {
showExtractFunctionPopup(project, editor, file, filteredCandidates, codeTransformer)
val top3Candidates = filteredCandidates.subList(0, minOf(3, filteredCandidates.size))
showExtractFunctionPopup(project, editor, file, top3Candidates, codeTransformer)
}
}
}
Expand Down Expand Up @@ -194,6 +228,7 @@ abstract class ApplyExtractFunctionTransformationIntention(
.setRequestFocus(true)
.setTitle(LLMBundle.message("ef.candidates.popup.title"))
.setResizable(true)
.setCancelOnClickOutside(false)
.setMovable(true).createPopup()

// Add onClosed listener
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package com.intellij.ml.llm.template.models

import com.google.gson.annotations.SerializedName
import com.intellij.ml.llm.template.extractfunction.EFSettingType
import com.intellij.ml.llm.template.extractfunction.EFSettings
import com.intellij.ml.llm.template.prompts.fewShotExtractSuggestion
import com.intellij.ml.llm.template.prompts.multishotExtractFunctionPrompt
import com.intellij.openapi.project.Project
import java.util.concurrent.TimeUnit

Expand All @@ -17,17 +20,26 @@ data class LlmMultishotResponseData(
)

class MultishotSender(val llmRequestProvider: LLMRequestProvider, val project: Project) {
fun sendRequest(data: String, maxShots: Int) : List<LlmMultishotResponseData> {
fun sendRequest(
data: String,
existingShots: List<Int>,
maxShots: Int,
temperature: Double? = null
): List<LlmMultishotResponseData> {
val result = mutableListOf<LlmMultishotResponseData>()

// get prompt
val messageList = fewShotExtractSuggestion(data)
val messageList = if (EFSettings.instance.has(EFSettingType.MULTISHOT_LEARNING)) multishotExtractFunctionPrompt(data) else fewShotExtractSuggestion(data)
val missingShots = getMissingShots(existingShots, maxShots)

for (shotNo in 1..maxShots) {
for (shotNo in missingShots) {
val startTime = System.nanoTime()
// send request
val llmResponse = sendChatRequest(
project, messageList, llmRequestProvider.chatModel, llmRequestProvider
project = project,
messages = messageList,
model = llmRequestProvider.chatModel,
temperature = temperature
)

val processingTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime)
Expand All @@ -36,4 +48,9 @@ class MultishotSender(val llmRequestProvider: LLMRequestProvider, val project: P

return result.sortedBy { it.shotNo }
}

/**
 * Returns, in ascending order, every shot number in 1..[maxShots] that is
 * not already present in [existingShots]. Empty when [maxShots] <= 0 or
 * when all shots have already been taken.
 */
private fun getMissingShots(existingShots: List<Int>, maxShots: Int): List<Int> =
    (1..maxShots).filter { it !in existingShots }
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,16 @@ class MyMethodExtractor(private val functionNameProvider: FunctionNameProvider?
computeWithAnalyzeProgress<List<ExtractOptions>, ExtractException>(file.project) {
findAllOptionsToExtract(elements)
}
val filteredOptionsToExtract = if (extractFunctionType == MethodExtractionType.PARENT_CLASS) {
allOptionsToExtract.filter { option ->
option.anchor !is PsiClass

var filteredOptionsToExtract = allOptionsToExtract

if (extractFunctionType == MethodExtractionType.PARENT_CLASS) {
filteredOptionsToExtract = filteredOptionsToExtract.filter { it.anchor !is PsiClass }
if (filteredOptionsToExtract.size > 1) {
filteredOptionsToExtract = filteredOptionsToExtract.subList(0, 1)
}
} else {
allOptionsToExtract
}

return selectOptionWithTargetClass(editor, filteredOptionsToExtract)
} catch (e: ExtractException) {
val message = JavaRefactoringBundle.message("extract.method.error.prefix") + " " + (e.message ?: "")
Expand Down Expand Up @@ -146,7 +149,7 @@ class MyMethodExtractor(private val functionNameProvider: FunctionNameProvider?
private fun suggestSafeMethodNames(options: ExtractOptions): List<String> {
var unsafeNames = guessMethodName(options)
if (functionNameProvider != null) {
unsafeNames = listOf(functionNameProvider.getFunctionName().name) + unsafeNames
unsafeNames = listOf(functionNameProvider.getFunctionName().name).filter { it.isNotEmpty() } + unsafeNames
}
val safeNames = unsafeNames.filterNot { name -> hasConflicts(options.copy(methodName = name)) }
if (safeNames.isNotEmpty()) return safeNames
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,14 @@ fun sendChatRequest(
project: Project,
messages: List<OpenAiChatMessage>,
model: String? = null,
temperature: Double? = null,
llmRequestProvider: LLMRequestProvider = GPTRequestProvider
): LLMBaseResponse? {
val request = llmRequestProvider.createChatGPTRequest(
OpenAiChatRequestBody(
model = model ?: llmRequestProvider.chatModel,
messages = messages
messages = messages,
temperature = temperature
)
)
return sendRequest(project, request)
Expand Down
Loading

0 comments on commit 985fd7b

Please sign in to comment.