From 45a75a1ec357cab821db66f89c3638aaf4a44e55 Mon Sep 17 00:00:00 2001 From: mahaloz Date: Mon, 16 Sep 2024 21:18:35 -0700 Subject: [PATCH 1/6] Add 2 new prompts for finding vulns and recovering man pages --- dailalib/__init__.py | 5 +- dailalib/api/ai_api.py | 11 +- dailalib/api/litellm/prompts/__init__.py | 36 ++- dailalib/api/litellm/prompts/cot_prompts.py | 230 ++++++++++++++++++ .../api/litellm/prompts/few_shot_prompts.py | 212 ++++++++++++++++ dailalib/api/litellm/prompts/prompt.py | 117 +++++++-- setup.cfg | 2 +- 7 files changed, 583 insertions(+), 30 deletions(-) diff --git a/dailalib/__init__.py b/dailalib/__init__.py index 7b10e2a..11e6a77 100644 --- a/dailalib/__init__.py +++ b/dailalib/__init__.py @@ -1,4 +1,4 @@ -__version__ = "3.7.1" +__version__ = "3.8.0" from .api import AIAPI, LiteLLMAIAPI from libbs.api import DecompilerInterface @@ -13,10 +13,9 @@ def create_plugin(*args, **kwargs): litellm_api = LiteLLMAIAPI(delay_init=True) # create context menus for prompts gui_ctx_menu_actions = { - f"DAILA/LLM/{prompt_name}": (prompt.desc, lambda *x, **y: getattr(litellm_api, prompt_name)(*x, **y)) + f"DAILA/LLM/{prompt_name}": (prompt.desc, getattr(litellm_api, prompt_name)) for prompt_name, prompt in litellm_api.prompts_by_name.items() } - # create context menus for others gui_ctx_menu_actions["DAILA/LLM/update_api_key"] = ("Update API key...", litellm_api.ask_api_key) gui_ctx_menu_actions["DAILA/LLM/update_pmpt_style"] = ("Change prompt style...", litellm_api.ask_prompt_style) gui_ctx_menu_actions["DAILA/LLM/update_model"] = ("Change model...", litellm_api.ask_model) diff --git a/dailalib/api/ai_api.py b/dailalib/api/ai_api.py index 81befe8..5146409 100644 --- a/dailalib/api/ai_api.py +++ b/dailalib/api/ai_api.py @@ -80,6 +80,8 @@ def _requires_function(*args, ai_api: "AIAPI" = None, **kwargs): dec_text = kwargs.pop("dec_text", None) use_dec = kwargs.pop("use_dec", True) has_self = kwargs.pop("has_self", True) + number_lines = kwargs.pop("number_lines", 
False) + context = kwargs.pop("context", None) # make the self object the new AI API, should only be used inside an AIAPI class if not ai_api and has_self: ai_api = args[0] @@ -95,7 +97,9 @@ def _requires_function(*args, ai_api: "AIAPI" = None, **kwargs): # we must have a UI if we have no func if function is None: - function = ai_api._dec_interface.functions[ai_api._dec_interface.gui_active_context().func_addr] + if context is None: + context = ai_api._dec_interface.gui_active_context() + function = ai_api._dec_interface.functions[context.func_addr] # get new text with the function that is present if dec_text is None: @@ -105,6 +109,11 @@ def _requires_function(*args, ai_api: "AIAPI" = None, **kwargs): dec_text = decompilation.text + if number_lines: + # put a number in front of each line + dec_lines = dec_text.split("\n") + dec_text = "\n".join([f"{i + 1} {line}" for i, line in enumerate(dec_lines)]) + return f(*args, function=function, dec_text=dec_text, use_dec=use_dec, **kwargs) return _requires_function diff --git a/dailalib/api/litellm/prompts/__init__.py b/dailalib/api/litellm/prompts/__init__.py index d146d8a..fde109e 100644 --- a/dailalib/api/litellm/prompts/__init__.py +++ b/dailalib/api/litellm/prompts/__init__.py @@ -7,43 +7,53 @@ class PromptNames: RENAME_VARS = "RENAME_VARIABLES" SUMMARIZE_FUNC = "SUMMARIZE_FUNCTION" ID_SRC = "IDENTIFY_SOURCE" + FIND_VULN = "FIND_VULN" + MAN_PAGE = "MAN_PAGE" def get_prompt_template(prompt_name, prompt_style): if prompt_style in [PromptType.FEW_SHOT, PromptType.ZERO_SHOT]: - from .few_shot_prompts import RENAME_FUNCTION, RENAME_VARIABLES, SUMMARIZE_FUNCTION, IDENTIFY_SOURCE + from .few_shot_prompts import ( + RENAME_FUNCTION, RENAME_VARIABLES, SUMMARIZE_FUNCTION, IDENTIFY_SOURCE, FIND_VULN, MAN_PAGE + ) d = { PromptNames.RENAME_FUNC: RENAME_FUNCTION, PromptNames.RENAME_VARS: RENAME_VARIABLES, PromptNames.SUMMARIZE_FUNC: SUMMARIZE_FUNCTION, - PromptNames.ID_SRC: IDENTIFY_SOURCE + PromptNames.ID_SRC: IDENTIFY_SOURCE, + 
PromptNames.FIND_VULN: FIND_VULN, + PromptNames.MAN_PAGE: MAN_PAGE, } elif prompt_style == PromptType.COT: - from .cot_prompts import RENAME_FUNCTION, RENAME_VARIABLES, SUMMARIZE_FUNCTION, IDENTIFY_SOURCE + from .cot_prompts import ( + RENAME_FUNCTION, RENAME_VARIABLES, SUMMARIZE_FUNCTION, IDENTIFY_SOURCE, FIND_VULN, MAN_PAGE + ) d = { PromptNames.RENAME_FUNC: RENAME_FUNCTION, PromptNames.RENAME_VARS: RENAME_VARIABLES, PromptNames.SUMMARIZE_FUNC: SUMMARIZE_FUNCTION, - PromptNames.ID_SRC: IDENTIFY_SOURCE + PromptNames.ID_SRC: IDENTIFY_SOURCE, + PromptNames.FIND_VULN: FIND_VULN, + PromptNames.MAN_PAGE: MAN_PAGE, } else: raise ValueError("Invalid prompt style") - return d[prompt_name] + return d.get(prompt_name, None) PROMPTS = [ Prompt( "summarize", PromptNames.SUMMARIZE_FUNC, - desc="Summarize the function", + desc="Summarize this function", response_key="summary", gui_result_callback=Prompt.comment_function ), Prompt( "identify_source", PromptNames.ID_SRC, - desc="Identify the source of the function", + desc="Identify the source of this function", response_key="link", gui_result_callback=Prompt.comment_function ), @@ -59,4 +69,16 @@ def get_prompt_template(prompt_name, prompt_style): desc="Suggest a function name", gui_result_callback=Prompt.rename_function ), + Prompt( + "find_vulnerabilities", + PromptNames.FIND_VULN, + desc="Find vulnerabilities in this function", + gui_result_callback=Prompt.comment_vulnerability + ), + Prompt( + "man_page", + PromptNames.MAN_PAGE, + desc="Summarize library call man page", + gui_result_callback=Prompt.comment_man_page + ), ] diff --git a/dailalib/api/litellm/prompts/cot_prompts.py b/dailalib/api/litellm/prompts/cot_prompts.py index a3a92bb..e571778 100644 --- a/dailalib/api/litellm/prompts/cot_prompts.py +++ b/dailalib/api/litellm/prompts/cot_prompts.py @@ -309,3 +309,233 @@ You respond with: """ + + +FIND_VULN = f""" +{COT_PREAMBLE} +All experts will be asked to identify vulnerabilities or bugs in code. 
When given code, you identify +vulnerabilities and specify the type of vulnerability. Only identify the MOST important vulnerabilities in the code. +Ignore bugs like resource leaks. +{COT_MIDAMBLE} +The question is how to identify vulnerabilities in the code given all the information we got. +Note that the vulnerabilities must be specific and include the line numbers where they occur. If you are unsure +of the vulnerability, please do not guess. +{COT_POSTAMBLE} +""" + """ +# Example +Here is an example. Given the following code: +``` +1 int __fastcall __noreturn main(int argc, const char **argv, const char **envp) +2 { +3 Human *v3; // rbx +4 __int64 v4; // rdx +5 Human *v5; // rbx +6 int v6; // eax +7 __int64 v7; // rax +8 Human *v8; // rbx +9 Human *v9; // rbx +10 char v10[16]; // [rsp+10h] [rbp-50h] BYREF +11 char v11[8]; // [rsp+20h] [rbp-40h] BYREF +12 Human *v12; // [rsp+28h] [rbp-38h] +13 Human *v13; // [rsp+30h] [rbp-30h] +14 size_t nbytes; // [rsp+38h] [rbp-28h] +15 void *buf; // [rsp+40h] [rbp-20h] +16 int v16; // [rsp+48h] [rbp-18h] BYREF +17 char v17; // [rsp+4Eh] [rbp-12h] BYREF +18 char v18[17]; // [rsp+4Fh] [rbp-11h] BYREF +19 +20 std::allocator::allocator(&v17, argv, envp); +21 std::string::string(v10, "Jack", &v17); +22 v3 = (Human *)operator new(0x18uLL); +23 Man::Man(v3, v10, 25LL); +24 v12 = v3; +25 std::string::~string((std::string *)v10); +26 std::allocator::~allocator(&v17); +27 std::allocator::allocator(v18, v10, v4); +28 std::string::string(v11, "Jill", v18); +29 v5 = (Human *)operator new(0x18uLL); +30 Woman::Woman(v5, v11, 21LL); +31 v13 = v5; +32 std::string::~string((std::string *)v11); +33 std::allocator::~allocator(v18); +34 while ( 1 ) +35 { +36 while ( 1 ) +37 { +38 while ( 1 ) +39 { +40 std::operator<<>(&std::cout, "1. use +41 2. after +42 3. 
free +43 "); +44 std::istream::operator>>(&std::cin, &v16); +45 if ( v16 != 2 ) +46 break; +47 nbytes = atoi(argv[1]); +48 buf = (void *)operator new[](nbytes); +49 v6 = open(argv[2], 0); +50 read(v6, buf, nbytes); +51 v7 = std::operator<<>(&std::cout, "your data is allocated"); +52 std::ostream::operator<<(v7, &std::endl>); +53 } +54 if ( v16 == 3 ) +55 break; +56 if ( v16 == 1 ) +57 { +58 (*(void (__fastcall **)(Human *))(*(_QWORD *)v12 + 8LL))(v12); +59 (*(void (__fastcall **)(Human *))(*(_QWORD *)v13 + 8LL))(v13); +60 } +61 } +62 v8 = v12; +63 if ( v12 ) +64 { +65 Human::~Human(v12); +66 operator delete(v8); +67 } +68 v9 = v13; +69 if ( v13 ) +70 { +71 Human::~Human(v13); +72 operator delete(v9); +73 } +74 } +75 } +``` + +You would respond with: +## Reasoning +### Expert 1: C/C++ Programming Expert +**Assessment**: The first vulnerability I notice is a potential use-after-free. Specifically, in lines 62-73, there are +deletions of the `v12` and `v13` objects. If the program re-enters this loop and tries to access these pointers without +proper reallocation, it will result in undefined behavior due to accessing freed memory. Additionally, lines 47-50 have +potential for buffer overflow. The size `nbytes` from `argv[1]` is used directly without any checks. If `argv[1]` is a +very large value, it can cause excessive allocation or even wrap around to a small value, potentially leading to an +overflow when reading data. + +### Expert 2: Reverse Engineering Expert +**Assessment**: One main issue is the use-after-free vulnerability in lines 62-73. Freeing `v12` and `v13` and then +potentially accessing them in subsequent iterations is problematic. This vulnerability can be exploited to crash the +program or execute arbitrary code. The second notable vulnerability is the insecure handling of `nbytes` in lines +47-50. 
Without validation, there's a risk that this unbounded value could lead to buffer overflow or memory corruption, +especially if `argv[1]` holds a negative or excessively large number. + +### Expert 3: Cybersecurity Analyst +**Assessment**: The use-after-free in lines 62-73 stands out as particularly severe. If the pointers `v12` and `v13` +are accessed after being freed, it can lead to security breaches or application crashes. Another critical point is the +lack of validation for `nbytes` in lines 47-50, which can potentially cause buffer overflow. This lack of sanitization +makes the application prone to memory corruption, which can be a severe security issue and possibly exploitable. + +## Answer +{ + "vulnerabilities": ["use-after-free (62-73)", "buffer-overflow (47-50)"], + "description": "The code contains a classic use-after-free vulnerability. In lines 62-73, the pointers v12 and v13 (which point to objects of type Human) are deleted (freed) using operator delete. If the program's loop (lines 34-74) executes again and the pointers v12 or v13 are accessed without reallocation, it results in undefined behavior due to use-after-free. In lines 47-50, the code reads a size value from argv[1] and uses it directly with operator new[] to allocate a buffer (buf). There are no checks to ensure that nbytes is a reasonable size, potentially leading to a large allocation or integer overflow." +} + +# Example +Given the following code: +``` +{{ decompilation }} +``` + +You respond with: +""" + +MAN_PAGE = f""" +{COT_PREAMBLE} +All experts will be asked to write a summarized man page for a function in a decompiled C code. +These summaries should include arg and return information as well as types. +The focal point will be on a function call (that is a library) inside this function. +{COT_MIDAMBLE} +The question is how to write a summarized man page for the target function given all the information we got. +A focal line will be given to do analysis on.
+{COT_POSTAMBLE} +""" + """ +# Example +Here is an example, given the following code as context: +``` +void __fastcall gz_error(__int64 a1, int a2, const char *a3) +{ + void *v5; // rcx + __int64 v7; // rbx + __int64 v8; // rax + __int64 v9; // rcx + char *v10; // rax + char *v11; // rcx + const char *v12; // r9 + __int64 v13; // rax + + v5 = *(void **)(a1 + 120); + if ( v5 ) + { + if ( *(_DWORD *)(a1 + 116) != -4 ) + free(v5); + *(_QWORD *)(a1 + 120) = 0LL; + } + if ( a2 && a2 != -5 ) + *(_DWORD *)a1 = 0; + *(_DWORD *)(a1 + 116) = a2; + if ( a3 && a2 != -4 ) + { + v7 = -1LL; + v8 = -1LL; + do + ++v8; + while ( *(_BYTE *)(*(_QWORD *)(a1 + 32) + v8) ); + v9 = -1LL; + do + ++v9; + while ( a3[v9] ); + v10 = (char *)malloc(v8 + 3 + v9); + *(_QWORD *)(a1 + 120) = v10; + v11 = v10; + if ( v10 ) + { + v12 = *(const char **)(a1 + 32); + v13 = -1LL; + while ( v12[++v13] != 0 ) + ; + do + ++v7; + while ( a3[v7] ); + snprintf(v11, v7 + v13 + 3, "%s%s%s", v12, ": ", a3); + } + else + { + *(_DWORD *)(a1 + 116) = -4; + } + } +} +``` + +You focus on the line in the above text: +``` + snprintf(v11, v7 + v13 + 3, "%s%s%s", v12, ": ", a3); +``` + +Focusing on the outermost function call in this line, you respond with: +## Reasoning +### Expert 1: C Programming Expert +### Expert 2: Reverse Engineering Expert +### Expert 3: Cybersecurity Analyst + +## Answer +{ + "function": "snprintf", + "args": ["str (char *)", "size (size_t)", "format (const char *)", "..."], + "return": "int", + "description": "The snprintf() function formats and stores a series of characters and values in the array buffer. It is similar to printf(), but with two major differences: it outputs to a buffer rather than stdout, and it takes an additional size parameter specifying the limit of characters to write. The size parameter prevents buffer overflows. It returns the number of characters that would have been written if the buffer was sufficiently large, not counting the terminating null character." 
+} + +# Example +Given the following code as context: +``` +{{ decompilation }} +``` + +You focus on the line in the above text: +``` +{{ line_text }} +``` + +Focusing on the outermost function call in this line, you respond with: +""" diff --git a/dailalib/api/litellm/prompts/few_shot_prompts.py b/dailalib/api/litellm/prompts/few_shot_prompts.py index 63a9bab..6b38560 100644 --- a/dailalib/api/litellm/prompts/few_shot_prompts.py +++ b/dailalib/api/litellm/prompts/few_shot_prompts.py @@ -166,4 +166,216 @@ ``` You respond with: +""" + +FIND_VULN = """ +# Task +You are a decompiled C expert that identifies vulnerabilities or bugs in code. When given code, you identify +vulnerabilities and specify the type of vulnerability. Only identify the MOST important vulnerabilities in the code. +Ignore bugs like resource leaks. + +You eventually respond with a valid json. As an example: +## Answer +{ + "vulnerabilities": ["command-injection (10-11)"], + "description": "The function is vulnerable to a command injection on line 10 in the call to system." +} + +{% if few_shot %} +# Example +Here is an example. 
Given the following code: +``` +1 int __fastcall __noreturn main(int argc, const char **argv, const char **envp) +2 { +3 Human *v3; // rbx +4 __int64 v4; // rdx +5 Human *v5; // rbx +6 int v6; // eax +7 __int64 v7; // rax +8 Human *v8; // rbx +9 Human *v9; // rbx +10 char v10[16]; // [rsp+10h] [rbp-50h] BYREF +11 char v11[8]; // [rsp+20h] [rbp-40h] BYREF +12 Human *v12; // [rsp+28h] [rbp-38h] +13 Human *v13; // [rsp+30h] [rbp-30h] +14 size_t nbytes; // [rsp+38h] [rbp-28h] +15 void *buf; // [rsp+40h] [rbp-20h] +16 int v16; // [rsp+48h] [rbp-18h] BYREF +17 char v17; // [rsp+4Eh] [rbp-12h] BYREF +18 char v18[17]; // [rsp+4Fh] [rbp-11h] BYREF +19 +20 std::allocator::allocator(&v17, argv, envp); +21 std::string::string(v10, "Jack", &v17); +22 v3 = (Human *)operator new(0x18uLL); +23 Man::Man(v3, v10, 25LL); +24 v12 = v3; +25 std::string::~string((std::string *)v10); +26 std::allocator::~allocator(&v17); +27 std::allocator::allocator(v18, v10, v4); +28 std::string::string(v11, "Jill", v18); +29 v5 = (Human *)operator new(0x18uLL); +30 Woman::Woman(v5, v11, 21LL); +31 v13 = v5; +32 std::string::~string((std::string *)v11); +33 std::allocator::~allocator(v18); +34 while ( 1 ) +35 { +36 while ( 1 ) +37 { +38 while ( 1 ) +39 { +40 std::operator<<>(&std::cout, "1. use +41 2. after +42 3. 
free +43 "); +44 std::istream::operator>>(&std::cin, &v16); +45 if ( v16 != 2 ) +46 break; +47 nbytes = atoi(argv[1]); +48 buf = (void *)operator new[](nbytes); +49 v6 = open(argv[2], 0); +50 read(v6, buf, nbytes); +51 v7 = std::operator<<>(&std::cout, "your data is allocated"); +52 std::ostream::operator<<(v7, &std::endl>); +53 } +54 if ( v16 == 3 ) +55 break; +56 if ( v16 == 1 ) +57 { +58 (*(void (__fastcall **)(Human *))(*(_QWORD *)v12 + 8LL))(v12); +59 (*(void (__fastcall **)(Human *))(*(_QWORD *)v13 + 8LL))(v13); +60 } +61 } +62 v8 = v12; +63 if ( v12 ) +64 { +65 Human::~Human(v12); +66 operator delete(v8); +67 } +68 v9 = v13; +69 if ( v13 ) +70 { +71 Human::~Human(v13); +72 operator delete(v9); +73 } +74 } +75 } +``` + +You would respond with: +## Answer +{ + "vulnerabilities": ["use-after-free (62-73)", "buffer-overflow (47-50)"], + "description": "The code contains a classic use-after-free vulnerability. In lines 62-73, the pointers v12 and v13 (which point to objects of type Human) are deleted (freed) using operator delete. If the program's loop (lines 34-74) executes again and the pointers v12 or v13 are accessed without reallocation, it results in undefined behavior due to use-after-free. In lines 47-50, the code reads a size value from argv[1] and uses it directly with operator new[] to allocate a buffer (buf). There are no checks to ensure that nbytes is a reasonable size, potentially leading to a large allocation or integer overflow." +} +{% endif %} + +# Example +Given the following code: +``` +{{ decompilation }} +``` + +You respond with: +""" + +MAN_PAGE = """ +# Task +You are a decompiled C expert that generates summarized man pages for functions. You are given the function the target +function is called in for context. You generate a summarized man page for the target function. +You only do this task for functions that are from libraries or the stdlib. Do not do this on user-defined functions. + +You eventually respond with a valid json.
As an example: +## Answer +{ + "function": "printf", + "args": ["format (char *)", "arg1 (void)", "arg2 (void)"], + "return": "int", + "description": "The printf() function writes output to stdout using the text formatting instructions contained in the format string. The format string can contain plain text and format specifiers that begin with the % character. Each format specifier is replaced by the value of the corresponding argument in the argument list. The printf() function returns the number of characters written to stdout." +} + + +# Example +Given the following code as context: +``` +void __fastcall gz_error(__int64 a1, int a2, const char *a3) +{ + void *v5; // rcx + __int64 v7; // rbx + __int64 v8; // rax + __int64 v9; // rcx + char *v10; // rax + char *v11; // rcx + const char *v12; // r9 + __int64 v13; // rax + + v5 = *(void **)(a1 + 120); + if ( v5 ) + { + if ( *(_DWORD *)(a1 + 116) != -4 ) + free(v5); + *(_QWORD *)(a1 + 120) = 0LL; + } + if ( a2 && a2 != -5 ) + *(_DWORD *)a1 = 0; + *(_DWORD *)(a1 + 116) = a2; + if ( a3 && a2 != -4 ) + { + v7 = -1LL; + v8 = -1LL; + do + ++v8; + while ( *(_BYTE *)(*(_QWORD *)(a1 + 32) + v8) ); + v9 = -1LL; + do + ++v9; + while ( a3[v9] ); + v10 = (char *)malloc(v8 + 3 + v9); + *(_QWORD *)(a1 + 120) = v10; + v11 = v10; + if ( v10 ) + { + v12 = *(const char **)(a1 + 32); + v13 = -1LL; + while ( v12[++v13] != 0 ) + ; + do + ++v7; + while ( a3[v7] ); + snprintf(v11, v7 + v13 + 3, "%s%s%s", v12, ": ", a3); + } + else + { + *(_DWORD *)(a1 + 116) = -4; + } + } +} +``` + +You focus on the line in the above text: +``` + snprintf(v11, v7 + v13 + 3, "%s%s%s", v12, ": ", a3); +``` + +Focusing on the outermost function call in this line, you respond with: +## Answer +{ + "function": "snprintf", + "args": ["str (char *)", "size (size_t)", "format (const char *)", "..."], + "return": "int", + "description": "The snprintf() function formats and stores a series of characters and values in the array buffer.
It is similar to printf(), but with two major differences: it outputs to a buffer rather than stdout, and it takes an additional size parameter specifying the limit of characters to write. The size parameter prevents buffer overflows. It returns the number of characters that would have been written if the buffer was sufficiently large, not counting the terminating null character." +} + +# Example +Given the following code as context: +``` +{{ decompilation }} +``` + +You focus on the line in the above text: +``` +{{ line_text }} +``` + +Focusing on the outermost function call in this line, you respond with: """ \ No newline at end of file diff --git a/dailalib/api/litellm/prompts/prompt.py b/dailalib/api/litellm/prompts/prompt.py index adbb378..1858440 100644 --- a/dailalib/api/litellm/prompts/prompt.py +++ b/dailalib/api/litellm/prompts/prompt.py @@ -7,18 +7,13 @@ from ..litellm_api import LiteLLMAIAPI from .prompt_type import PromptType -from libbs.artifacts import Comment, Function, StackVariable +from libbs.artifacts import Comment, Function, Context from jinja2 import Template, StrictUndefined JSON_REGEX = re.compile(r"\{.*?}", flags=re.DOTALL) class Prompt: - DECOMP_REPLACEMENT_LABEL = "" - SNIPPET_REPLACEMENT_LABEL = "" - SNIPPET_TEXT = f"\n\"\"\"{SNIPPET_REPLACEMENT_LABEL}\"\"\"" - DECOMP_TEXT = f"\n\"\"\"{DECOMP_REPLACEMENT_LABEL}\"\"\"" - def __init__( self, name: str, @@ -28,6 +23,7 @@ def __init__( posttext_response: Optional[str] = None, json_response: bool = True, response_key: str = None, + number_lines: bool = False, ai_api=None, # callback(result, function, ai_api) gui_result_callback: Optional[Callable] = None @@ -40,15 +36,25 @@ def __init__( self._json_response = json_response self._response_key = response_key self._gui_result_callback = gui_result_callback + self._number_lines = number_lines self.desc = desc or name self.ai_api: LiteLLMAIAPI = ai_api + def __str__(self): + return f"" + + def __repr__(self): + return self.__str__() + def 
_load_template(self, prompt_style: PromptType) -> Template: from . import get_prompt_template template_text = get_prompt_template(self.template_name, prompt_style) + if template_text is None: + raise ValueError(f"Prompt template {self.template_name} not supported in {prompt_style} style!") + return Template(textwrap.dedent(template_text), undefined=StrictUndefined) - def query_model(self, *args, function=None, dec_text=None, use_dec=True, **kwargs): + def query_model(self, *args, context=None, function=None, dec_text=None, use_dec=True, **kwargs): if self.ai_api is None: raise Exception("api must be set before querying!") @@ -58,21 +64,29 @@ def _query_model(ai_api=self.ai_api, function=function, dec_text=dec_text, **_kw return {} ai_api.info(f"Querying {self.name} prompt with function {function}...") + # construct the intial template response = self._pretext_response if self._pretext_response and not self._json_response else "" template = self._load_template(self.ai_api.prompt_style) + # grab decompilation and replace it in the prompt, make sure to fix the decompilation for token max + dec_lines = dec_text.split("\n") query_text = template.render( - decompilation=LiteLLMAIAPI.fit_decompilation_to_token_max(dec_text) if self.ai_api.fit_to_tokens else dec_text, + # decompilation lines of the target function + decompilation=LiteLLMAIAPI.fit_decompilation_to_token_max(dec_text) + if self.ai_api.fit_to_tokens else dec_text, + # line text for emphasis + line_text=dec_lines[context.line_number] if context.line_number is not None else "", + # prompting style (engineering technique) few_shot=bool(self.ai_api.prompt_style == PromptType.FEW_SHOT), ) self.last_rendered_template = query_text - #ai_api.info(f"Prompting using model: {self.ai_api.model}...") - #ai_api.info(f"Prompting with style: {self.ai_api.prompt_style}...") - #ai_api.info(f"Prompting with: {query_text}") + ai_api.info(f"Prompting using model: {self.ai_api.model}...") + ai_api.info(f"Prompting with style: 
{self.ai_api.prompt_style}...") + ai_api.info(f"Prompting with: {query_text}") ai_api.on_query(self.name, self.ai_api.model, self.ai_api.prompt_style, function, dec_text) response += self.ai_api.query_model(query_text) - #ai_api.info(f"Response received from AI: {response}") + ai_api.info(f"Response received from AI: {response}") default_response = {} if self._json_response else "" if not response: ai_api.warning(f"Response received from AI was empty! AI failed to answer.") @@ -106,17 +120,21 @@ def _query_model(ai_api=self.ai_api, function=function, dec_text=dec_text, **_kw else: ai_api.warning(f"Response recieved from AI, but it was empty! AI failed to answer.") else: - ai_api.info("Reponse received from AI!") + ai_api.info("Response received from AI!") if ai_api.has_decompiler_gui and response: ai_api.info("Updating the decompiler with the AI response...") - self._gui_result_callback(response, function, ai_api) + self._gui_result_callback(response, function, ai_api, context=context) return response - return _query_model(ai_api=self.ai_api, function=function, dec_text=dec_text, use_dec=use_dec) + + return _query_model( + ai_api=self.ai_api, function=function, dec_text=dec_text, use_dec=use_dec, number_lines=self._number_lines, + context=context + ) @staticmethod - def rename_function(result, function, ai_api: "AIAPI"): + def rename_function(result, function, ai_api: "AIAPI", **kwargs): if function.name in result: new_name = result[function.name] else: @@ -126,7 +144,7 @@ def rename_function(result, function, ai_api: "AIAPI"): ai_api._dec_interface.functions[function.addr] = new_func @staticmethod - def rename_variables(result, function, ai_api: "AIAPI"): + def rename_variables(result, function, ai_api: "AIAPI", **kwargs): new_func: Function = function.copy() # clear out changes that are not for variables new_func.name = None @@ -134,7 +152,7 @@ def rename_variables(result, function, ai_api: "AIAPI"): ai_api._dec_interface.rename_local_variables_by_names(function, 
result) @staticmethod - def comment_function(result, function, ai_api: "AIAPI"): + def comment_function(result, function, ai_api: "AIAPI", **kwargs): curr_cmt_obj = ai_api._dec_interface.comments.get(function.addr, None) curr_cmt = curr_cmt_obj.comment + "\n" if curr_cmt_obj is not None else "" @@ -143,3 +161,66 @@ def comment_function(result, function, ai_api: "AIAPI"): comment=curr_cmt + result, func_addr=function.addr ) + + @staticmethod + def comment_vulnerability(result, function, ai_api: "AIAPI", **kwargs): + rendered = "" + if "vulnerabilities" in result and "description" in result: + rendered += "Vulnerabilities:\n" + for vuln in result["vulnerabilities"]: + rendered += f"- {vuln}\n" + + rendered += "\nVuln Analysis:\n" + rendered += result["description"] + elif isinstance(result, dict): + for key, value in result.items(): + rendered += f"{key}: {value}\n" + else: + rendered = str(result) + + bs_cmt = Comment( + addr=function.addr, + comment=rendered, + func_addr=function.addr + ) + bs_cmt_lines = len(Comment.linewrap_comment(bs_cmt.comment).splitlines()) + + # adjust the lines specified in the comment + # + # find all the line numbers in the comment of form 'lines 23-24' or '-23' or '23-' + nums = set(re.findall("lines (\d+)", rendered)) | set(re.findall("-(\d+)", rendered)) | \ + set(re.findall("(\d+)-", rendered)) + # replace the largest digit numbers first + sorted_nums = sorted(nums, key=lambda x: int(x), reverse=True) + for num in sorted_nums: + _n = int(num, 0) + new_num = str(_n + bs_cmt_lines - 2) + rendered = rendered.replace(num, new_num) + + ai_api._dec_interface.comments[function.addr] = Comment( + addr=function.addr, + comment=rendered, + func_addr=function.addr + ) + + @staticmethod + def comment_man_page(result, function, ai_api: "AIAPI", context=None, **kwargs): + rendered = "\n" + if "function" in result and "args" in result and "return" in result and "description" in result: + rendered += f"Man Page for {result['function']}:\n" + rendered 
+= f"Args: {', '.join(result['args'])}\n" + rendered += f"Return: {result['return']}\n" + rendered += f"Description: {result['description']}\n" + elif isinstance(result, dict): + for key, value in result.items(): + rendered += f"{key}: {value}\n" + else: + rendered = str(result) + + addr = context.addr if isinstance(context, Context) and context.addr is not None else function.addr + ai_api._dec_interface.comments[addr] = Comment( + addr=addr, + comment=rendered, + func_addr=function.addr, + decompiled=True + ) diff --git a/setup.cfg b/setup.cfg index cea76eb..652b4e5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,7 +17,7 @@ install_requires = litellm>=1.44.27 tiktoken Jinja2 - libbs>=1.22.0 + libbs>=1.23.0 python_requires = >= 3.10 include_package_data = True From d974bac366d57204e9a3f1f54eb1b2ae934b633a Mon Sep 17 00:00:00 2001 From: mahaloz Date: Mon, 16 Sep 2024 21:20:07 -0700 Subject: [PATCH 2/6] update readme --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index ca689aa..39cde68 100644 --- a/README.md +++ b/README.md @@ -108,6 +108,8 @@ Currently, DAILA supports the following prompts: - Rename variables - Rename function - Identify the source of a function +- Find potential vulnerabilities in a function +- Summarize the man page of a library call ### VarBERT VarBERT is a local BERT model from the S&P 2024 paper [""Len or index or count, anything but v1": Predicting Variable Names in Decompilation Output with Transfer Learning"](). 
From fc3bcdb43bc13f907f8544eee7387ad9dae5aa77 Mon Sep 17 00:00:00 2001 From: mahaloz Date: Mon, 16 Sep 2024 21:20:57 -0700 Subject: [PATCH 3/6] put back --- dailalib/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dailalib/__init__.py b/dailalib/__init__.py index 11e6a77..a745818 100644 --- a/dailalib/__init__.py +++ b/dailalib/__init__.py @@ -16,6 +16,7 @@ def create_plugin(*args, **kwargs): f"DAILA/LLM/{prompt_name}": (prompt.desc, getattr(litellm_api, prompt_name)) for prompt_name, prompt in litellm_api.prompts_by_name.items() } + # create context menus for others gui_ctx_menu_actions["DAILA/LLM/update_api_key"] = ("Update API key...", litellm_api.ask_api_key) gui_ctx_menu_actions["DAILA/LLM/update_pmpt_style"] = ("Change prompt style...", litellm_api.ask_prompt_style) gui_ctx_menu_actions["DAILA/LLM/update_model"] = ("Change model...", litellm_api.ask_model) From e856ce8e402011290933c1f5ee73c82945e4e286 Mon Sep 17 00:00:00 2001 From: mahaloz Date: Mon, 16 Sep 2024 21:22:15 -0700 Subject: [PATCH 4/6] remove debug --- dailalib/api/litellm/prompts/prompt.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dailalib/api/litellm/prompts/prompt.py b/dailalib/api/litellm/prompts/prompt.py index 1858440..9f4d6e3 100644 --- a/dailalib/api/litellm/prompts/prompt.py +++ b/dailalib/api/litellm/prompts/prompt.py @@ -80,13 +80,13 @@ def _query_model(ai_api=self.ai_api, function=function, dec_text=dec_text, **_kw few_shot=bool(self.ai_api.prompt_style == PromptType.FEW_SHOT), ) self.last_rendered_template = query_text - ai_api.info(f"Prompting using model: {self.ai_api.model}...") - ai_api.info(f"Prompting with style: {self.ai_api.prompt_style}...") - ai_api.info(f"Prompting with: {query_text}") + #ai_api.info(f"Prompting using model: {self.ai_api.model}...") + #ai_api.info(f"Prompting with style: {self.ai_api.prompt_style}...") + #ai_api.info(f"Prompting with: {query_text}") ai_api.on_query(self.name, self.ai_api.model, 
self.ai_api.prompt_style, function, dec_text) response += self.ai_api.query_model(query_text) - ai_api.info(f"Response received from AI: {response}") + #ai_api.info(f"Response received from AI: {response}") default_response = {} if self._json_response else "" if not response: ai_api.warning(f"Response received from AI was empty! AI failed to answer.") From d35894683f0c26297dfe5b1772d4e438a39517f4 Mon Sep 17 00:00:00 2001 From: mahaloz Date: Mon, 16 Sep 2024 21:22:56 -0700 Subject: [PATCH 5/6] Fix version --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 652b4e5..9e0589a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,7 +17,7 @@ install_requires = litellm>=1.44.27 tiktoken Jinja2 - libbs>=1.23.0 + libbs>=1.23.1 python_requires = >= 3.10 include_package_data = True From 8b45849c548c94f2a8607aa78ac95d6314bce9b3 Mon Sep 17 00:00:00 2001 From: mahaloz Date: Mon, 16 Sep 2024 21:25:00 -0700 Subject: [PATCH 6/6] Hide things --- dailalib/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dailalib/__init__.py b/dailalib/__init__.py index a745818..ae20117 100644 --- a/dailalib/__init__.py +++ b/dailalib/__init__.py @@ -17,9 +17,9 @@ def create_plugin(*args, **kwargs): for prompt_name, prompt in litellm_api.prompts_by_name.items() } # create context menus for others - gui_ctx_menu_actions["DAILA/LLM/update_api_key"] = ("Update API key...", litellm_api.ask_api_key) - gui_ctx_menu_actions["DAILA/LLM/update_pmpt_style"] = ("Change prompt style...", litellm_api.ask_prompt_style) - gui_ctx_menu_actions["DAILA/LLM/update_model"] = ("Change model...", litellm_api.ask_model) + gui_ctx_menu_actions["DAILA/LLM/Settings/update_api_key"] = ("Update API key...", litellm_api.ask_api_key) + gui_ctx_menu_actions["DAILA/LLM/Settings/update_pmpt_style"] = ("Change prompt style...", litellm_api.ask_prompt_style) + gui_ctx_menu_actions["DAILA/LLM/Settings/update_model"] = ("Change model...", 
litellm_api.ask_model) # # VarModel API (local variable renaming)