diff --git a/README.md b/README.md index 48e12621..e86ce4a5 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ For building binary if you wish to build from source, then `cargo` is required. build = "make", -- build = "powershell -ExecutionPolicy Bypass -File Build.ps1 -BuildFromSource false" -- for windows dependencies = { + "nvim-treesitter/nvim-treesitter", "stevearc/dressing.nvim", "nvim-lua/plenary.nvim", "MunifTanjim/nui.nvim", diff --git a/crates/avante-templates/src/lib.rs b/crates/avante-templates/src/lib.rs index 81703fef..0f05669b 100644 --- a/crates/avante-templates/src/lib.rs +++ b/crates/avante-templates/src/lib.rs @@ -21,6 +21,7 @@ struct TemplateContext { ask: bool, question: String, code_lang: String, + filepath: String, file_content: String, selected_code: Option, project_context: Option, @@ -45,6 +46,7 @@ fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult< ask => context.ask, question => context.question, code_lang => context.code_lang, + filepath => context.filepath, file_content => context.file_content, selected_code => context.selected_code, project_context => context.project_context, diff --git a/lua/avante/config.lua b/lua/avante/config.lua index ebbc31ea..bef41378 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -92,6 +92,7 @@ You are an excellent programming expert. ---3. auto_set_highlight_group : Whether to automatically set the highlight group for the current line. Default to true. ---4. support_paste_from_clipboard : Whether to support pasting image from clipboard. This will be determined automatically based whether img-clip is available or not. behaviour = { + enable_project_context = false, auto_suggestions = false, -- Experimental stage auto_set_highlight_group = true, auto_set_keymaps = true, diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 99bc70b3..a14576aa 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -59,11 +59,14 @@ M.stream = function(opts) Path.prompts.initialize(Path.prompts.get(opts.bufnr)) + local filepath = Utils.relative_path(api.nvim_buf_get_name(opts.bufnr)) + local template_opts = { use_xml_format = Provider.use_xml_format, ask = opts.ask, -- TODO: add mode without ask instruction question = original_instructions, code_lang = opts.code_lang, + filepath = filepath, file_content = opts.file_content, selected_code = opts.selected_code, project_context = opts.project_context, diff --git a/lua/avante/path.lua b/lua/avante/path.lua index 82c4f7d5..46f20c55 100644 --- a/lua/avante/path.lua +++ b/lua/avante/path.lua @@ -28,17 +28,17 @@ end H.get_mode_file = function(mode) return string.format("custom.%s.avanterules", mode) end -- History path -local M = {} +local History = {} -- Returns the Path to the chat history file for the given buffer. ---@param bufnr integer ---@return Path -M.get = function(bufnr) return Path:new(Config.history.storage_path):joinpath(H.filename(bufnr)) end +History.get = function(bufnr) return Path:new(Config.history.storage_path):joinpath(H.filename(bufnr)) end -- Loads the chat history for the given buffer. ---@param bufnr integer -M.load = function(bufnr) - local history_file = M.get(bufnr) +History.load = function(bufnr) + local history_file = History.get(bufnr) if history_file:exists() then local content = history_file:read() return content ~= nil and vim.json.decode(content) or {} @@ -49,29 +49,29 @@ end -- Saves the chat history for the given buffer. ---@param bufnr integer ---@param history table -M.save = function(bufnr, history) - local history_file = M.get(bufnr) +History.save = function(bufnr, history) + local history_file = History.get(bufnr) history_file:write(vim.json.encode(history), "w") end -P.history = M +P.history = History -- Prompt path -local N = {} +local Prompt = {} ---@class AvanteTemplates ---@field initialize fun(directory: string): nil ---@field render fun(template: string, context: TemplateOptions): string local templates = nil -N.templates = { planning = nil, editing = nil, suggesting = nil } +Prompt.templates = { planning = nil, editing = nil, suggesting = nil } -- Creates a directory in the cache path for the given buffer and copies the custom prompts to it. -- We need to do this beacuse the prompt template engine requires a given directory to load all required files. -- PERF: Hmm instead of copy to cache, we can also load in globals context, but it requires some work on bindings. (eh maybe?) ---@param bufnr number ---@return string the resulted cache_directory to be loaded with avante_templates -N.get = function(bufnr) +Prompt.get = function(bufnr) if not P.available() then error("Make sure to build avante (missing avante_templates)", 2) end -- get root directory of given bufnr @@ -85,19 +85,19 @@ N.get = function(bufnr) local scanner = Scan.scan_dir(directory:absolute(), { depth = 1, add_dirs = true }) for _, entry in ipairs(scanner) do local file = Path:new(entry) - if entry:find("planning") and N.templates.planning == nil then - N.templates.planning = file:read() - elseif entry:find("editing") and N.templates.editing == nil then - N.templates.editing = file:read() - elseif entry:find("suggesting") and N.templates.suggesting == nil then - N.templates.suggesting = file:read() + if entry:find("planning") and Prompt.templates.planning == nil then + Prompt.templates.planning = file:read() + elseif entry:find("editing") and Prompt.templates.editing == nil then + Prompt.templates.editing = file:read() + elseif entry:find("suggesting") and Prompt.templates.suggesting == nil then + Prompt.templates.suggesting = file:read() end end Path:new(debug.getinfo(1).source:match("@?(.*/)"):gsub("/lua/avante/path.lua$", "") .. "templates") :copy({ destination = cache_prompt_dir, recursive = true }) - vim.iter(N.templates):filter(function(_, v) return v ~= nil end):each(function(k, v) + vim.iter(Prompt.templates):filter(function(_, v) return v ~= nil end):each(function(k, v) local f = cache_prompt_dir:joinpath(H.get_mode_file(k)) f:write(v, "w") end) @@ -106,22 +106,53 @@ N.get = function(bufnr) end ---@param mode LlmMode -N.get_file = function(mode) - if N.templates[mode] ~= nil then return H.get_mode_file(mode) end +Prompt.get_file = function(mode) + if Prompt.templates[mode] ~= nil then return H.get_mode_file(mode) end return string.format("%s.avanterules", mode) end ---@param path string ---@param opts TemplateOptions -N.render_file = function(path, opts) return templates.render(path, opts) end +Prompt.render_file = function(path, opts) return templates.render(path, opts) end ---@param mode LlmMode ---@param opts TemplateOptions -N.render_mode = function(mode, opts) return templates.render(N.get_file(mode), opts) end +Prompt.render_mode = function(mode, opts) return templates.render(Prompt.get_file(mode), opts) end -N.initialize = function(directory) templates.initialize(directory) end +Prompt.initialize = function(directory) templates.initialize(directory) end -P.prompts = N +P.prompts = Prompt + +local RepoMap = {} + +-- Get a chat history file name given a buffer +---@param project_root string +---@param ext string +---@return string +RepoMap.filename = function(project_root, ext) + -- Replace path separators with double underscores + local path_with_separators = fn.substitute(project_root, "/", "__", "g") + -- Replace other non-alphanumeric characters with single underscores + return fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g") .. "." .. ext .. ".repo_map.json" +end + +RepoMap.get = function(project_root, ext) return Path:new(P.data_path):joinpath(RepoMap.filename(project_root, ext)) end + +RepoMap.save = function(project_root, ext, data) + local file = RepoMap.get(project_root, ext) + file:write(vim.json.encode(data), "w") +end + +RepoMap.load = function(project_root, ext) + local file = RepoMap.get(project_root, ext) + if file:exists() then + local content = file:read() + return content ~= nil and vim.json.decode(content) or {} + end + return nil +end + +P.repo_map = RepoMap P.setup = function() local history_path = Path:new(Config.history.storage_path) @@ -132,6 +163,10 @@ P.setup = function() if not cache_path:exists() then cache_path:mkdir({ parents = true }) end P.cache_path = cache_path + local data_path = Path:new(vim.fn.stdpath("data") .. "/avante") + if not data_path:exists() then data_path:mkdir({ parents = true }) end + P.data_path = data_path + vim.defer_fn(function() local ok, module = pcall(require, "avante_templates") ---@cast module AvanteTemplates diff --git a/lua/avante/selection.lua b/lua/avante/selection.lua index 272cad3a..eb855eb4 100644 --- a/lua/avante/selection.lua +++ b/lua/avante/selection.lua @@ -392,10 +392,12 @@ function Selection:create_editing_input() end local filetype = api.nvim_get_option_value("filetype", { buf = code_bufnr }) + local project_context = Utils.repo_map.get_repo_map() Llm.stream({ bufnr = code_bufnr, ask = true, + project_context = vim.json.encode(project_context), file_content = code_content, code_lang = filetype, selected_code = self.selection.content, diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index a47c2235..5304def2 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -1270,9 +1270,12 @@ function Sidebar:create_input(opts) Path.history.save(self.code.bufnr, chat_history) end + local project_context = Utils.repo_map.get_repo_map() + Llm.stream({ bufnr = self.code.bufnr, ask = opts.ask, + project_context = vim.json.encode(project_context), file_content = content_with_line_numbers, code_lang = filetype, selected_code = selected_code_content_with_line_numbers, diff --git a/lua/avante/suggestion.lua b/lua/avante/suggestion.lua index d20d74a9..06ca32ed 100644 --- a/lua/avante/suggestion.lua +++ b/lua/avante/suggestion.lua @@ -65,9 +65,13 @@ function Suggestion:suggest() local full_response = "" + local project_context = Utils.repo_map.get_repo_map() + print("project_context", vim.inspect(project_context)) + Llm.stream({ bufnr = bufnr, ask = true, + project_context = vim.json.encode(project_context), file_content = code_content, code_lang = filetype, instructions = vim.json.encode(doc), @@ -79,24 +83,26 @@ function Suggestion:suggest() return end Utils.debug("full_response: " .. vim.inspect(full_response)) - local cursor_row, cursor_col = Utils.get_cursor_pos() - if cursor_row ~= doc.position.row or cursor_col ~= doc.position.col then return end - local ok, suggestions = pcall(vim.json.decode, full_response) - if not ok then - Utils.error("Error while decoding suggestions: " .. full_response, { once = true, title = "Avante" }) - return - end - if not suggestions then - Utils.info("No suggestions found", { once = true, title = "Avante" }) - return - end - suggestions = vim - .iter(suggestions) - :map(function(s) return { row = s.row, col = s.col, content = Utils.trim_all_line_numbers(s.content) } end) - :totable() - ctx.suggestions = suggestions - ctx.current_suggestion_idx = 1 - self:show() + vim.schedule(function() + local cursor_row, cursor_col = Utils.get_cursor_pos() + if cursor_row ~= doc.position.row or cursor_col ~= doc.position.col then return end + local ok, suggestions = pcall(vim.json.decode, full_response) + if not ok then + Utils.error("Error while decoding suggestions: " .. full_response, { once = true, title = "Avante" }) + return + end + if not suggestions then + Utils.info("No suggestions found", { once = true, title = "Avante" }) + return + end + suggestions = vim + .iter(suggestions) + :map(function(s) return { row = s.row, col = s.col, content = Utils.trim_all_line_numbers(s.content) } end) + :totable() + ctx.suggestions = suggestions + ctx.current_suggestion_idx = 1 + self:show() + end) end, }) end diff --git a/lua/avante/templates/_context.avanterules b/lua/avante/templates/_context.avanterules index 6f60bdb4..7fe42d9b 100644 --- a/lua/avante/templates/_context.avanterules +++ b/lua/avante/templates/_context.avanterules @@ -1,4 +1,6 @@ {%- if use_xml_format -%} +{{filepath}} + {%- if selected_code -%} ```{{code_lang}} @@ -19,6 +21,8 @@ {%- endif %} {% else %} +FILEPATH: {{filepath}} + {%- if selected_code -%} CONTEXT: ```{{code_lang}} diff --git a/lua/avante/utils/file.lua b/lua/avante/utils/file.lua new file mode 100644 index 00000000..a043df9a --- /dev/null +++ b/lua/avante/utils/file.lua @@ -0,0 +1,36 @@ +local LRUCache = require("avante.utils.lru_cache") + +---@class avante.utils.file +local M = {} + +local api = vim.api +local fn = vim.fn + +local _file_content_lru_cache = LRUCache:new(60) + +api.nvim_create_autocmd("BufWritePost", { + callback = function() + local filepath = api.nvim_buf_get_name(0) + local keys = _file_content_lru_cache:keys() + if vim.tbl_contains(keys, filepath) then + local content = table.concat(api.nvim_buf_get_lines(0, 0, -1, false), "\n") + _file_content_lru_cache:set(filepath, content) + end + end, +}) + +function M.read_content(filepath) + local cached_content = _file_content_lru_cache:get(filepath) + if cached_content then return cached_content end + + local content = fn.readfile(filepath) + if content then + content = table.concat(content, "\n") + _file_content_lru_cache:set(filepath, content) + return content + end + + return nil +end + +return M diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index 6a42b842..b3389c90 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -5,6 +5,7 @@ local lsp = vim.lsp ---@class avante.utils: LazyUtilCore ---@field tokens avante.utils.tokens ---@field root avante.utils.root +---@field repo_map avante.utils.repo_map local M = {} setmetatable(M, { @@ -444,7 +445,7 @@ function M.get_indentation(code) return code:match("^%s*") or "" end --- remove indentation from code: spaces or tabs function M.remove_indentation(code) return code:gsub("^%s*", "") end -local function relative_path(absolute) +function M.relative_path(absolute) local relative = fn.fnamemodify(absolute, ":.") if string.sub(relative, 0, 1) == "/" then return fn.fnamemodify(absolute, ":t") end return relative @@ -462,7 +463,7 @@ function M.get_doc() local doc = { uri = params.textDocument.uri, version = api.nvim_buf_get_var(0, "changedtick"), - relativePath = relative_path(absolute), + relativePath = M.relative_path(absolute), insertSpaces = vim.o.expandtab, tabSize = fn.shiftwidth(), indentSize = fn.shiftwidth(), @@ -520,4 +521,99 @@ function M.winline(winid) return line end +function M.get_project_root() return M.root.get() end + +function M.is_same_file_ext(target_ext, filepath) + local ext = fn.fnamemodify(filepath, ":e") + if target_ext == "tsx" and ext == "ts" then return true end + if target_ext == "jsx" and ext == "js" then return true end + return ext == target_ext +end + +-- Get recent filepaths in the same project and same file ext +function M.get_recent_filepaths(limit, filenames) + local project_root = M.get_project_root() + local current_ext = fn.expand("%:e") + local oldfiles = vim.v.oldfiles + local recent_files = {} + + for _, file in ipairs(oldfiles) do + if vim.startswith(file, project_root) and M.is_same_file_ext(current_ext, file) then + if filenames and #filenames > 0 then + for _, filename in ipairs(filenames) do + if file:find(filename) then table.insert(recent_files, file) end + end + else + table.insert(recent_files, file) + end + if #recent_files >= (limit or 10) then break end + end + end + + return recent_files +end + +local function pattern_to_lua(pattern) + local lua_pattern = pattern:gsub("[%(%)%.%%%+%-%*%?%[%]%^%$]", "%%%1") + lua_pattern = lua_pattern:gsub("%*%*/", ".-/") + lua_pattern = lua_pattern:gsub("%*", "[^/]*") + lua_pattern = lua_pattern:gsub("%?", ".") + if lua_pattern:sub(-1) == "/" then lua_pattern = lua_pattern .. ".*" end + return lua_pattern +end + +function M.parse_gitignore(gitignore_path) + local ignore_patterns = { ".git", ".worktree", "__pycache__", "node_modules" } + local negate_patterns = {} + local file = io.open(gitignore_path, "r") + if not file then return ignore_patterns, negate_patterns end + + for line in file:lines() do + if line:match("%S") and not line:match("^#") then + local trimmed_line = line:match("^%s*(.-)%s*$") + if trimmed_line:sub(1, 1) == "!" then + table.insert(negate_patterns, pattern_to_lua(trimmed_line:sub(2))) + else + table.insert(ignore_patterns, pattern_to_lua(trimmed_line)) + end + end + end + + file:close() + return ignore_patterns, negate_patterns +end + +local function is_ignored(file, ignore_patterns, negate_patterns) + for _, pattern in ipairs(negate_patterns) do + if file:match(pattern) then return false end + end + for _, pattern in ipairs(ignore_patterns) do + if file:match(pattern) then return true end + end + return false +end + +function M.scan_directory(directory, ignore_patterns, negate_patterns) + local files = {} + local handle = vim.loop.fs_scandir(directory) + + if not handle then return files end + + while true do + local name, type = vim.loop.fs_scandir_next(handle) + if not name then break end + + local full_path = directory .. "/" .. name + if type == "directory" then + vim.list_extend(files, M.scan_directory(full_path, ignore_patterns, negate_patterns)) + elseif type == "file" then + if not is_ignored(full_path, ignore_patterns, negate_patterns) then table.insert(files, full_path) end + end + end + + return files +end + +function M.is_first_letter_uppercase(str) return string.match(str, "^[A-Z]") ~= nil end + return M diff --git a/lua/avante/utils/lru_cache.lua b/lua/avante/utils/lru_cache.lua new file mode 100644 index 00000000..a6f5bb80 --- /dev/null +++ b/lua/avante/utils/lru_cache.lua @@ -0,0 +1,115 @@ +local LRUCache = {} +LRUCache.__index = LRUCache + +function LRUCache:new(capacity) + return setmetatable({ + capacity = capacity, + cache = {}, + head = nil, + tail = nil, + size = 0, + }, LRUCache) +end + +-- Internal function: Move node to head (indicating most recently used) +function LRUCache:_move_to_head(node) + if self.head == node then return end + + -- Disconnect the node + if node.prev then node.prev.next = node.next end + + if node.next then node.next.prev = node.prev end + + if self.tail == node then self.tail = node.prev end + + -- Insert the node at the head + node.next = self.head + node.prev = nil + + if self.head then self.head.prev = node end + self.head = node + + if not self.tail then self.tail = node end +end + +-- Get value from cache +function LRUCache:get(key) + local node = self.cache[key] + if not node then return nil end + + self:_move_to_head(node) + + return node.value +end + +-- Set value in cache +function LRUCache:set(key, value) + local node = self.cache[key] + + if node then + node.value = value + self:_move_to_head(node) + else + node = { key = key, value = value } + self.cache[key] = node + self.size = self.size + 1 + + self:_move_to_head(node) + + if self.size > self.capacity then + local tail_key = self.tail.key + self.tail = self.tail.prev + if self.tail then self.tail.next = nil end + self.cache[tail_key] = nil + self.size = self.size - 1 + end + end +end + +-- Remove specified cache entry +function LRUCache:remove(key) + local node = self.cache[key] + if not node then return end + + if node.prev then + node.prev.next = node.next + else + self.head = node.next + end + + if node.next then + node.next.prev = node.prev + else + self.tail = node.prev + end + + self.cache[key] = nil + self.size = self.size - 1 +end + +-- Get current size of cache +function LRUCache:get_size() return self.size end + +-- Get capacity of cache +function LRUCache:get_capacity() return self.capacity end + +-- Print current cache contents (for debugging) +function LRUCache:print_cache() + local node = self.head + while node do + print(node.key, node.value) + node = node.next + end +end + +function LRUCache:keys() + local keys = {} + local node = self.head + while node do + table.insert(keys, node.key) + node = node.next + end + return keys +end + +return LRUCache diff --git a/lua/avante/utils/repo_map.lua b/lua/avante/utils/repo_map.lua new file mode 100644 index 00000000..a9b940b0 --- /dev/null +++ b/lua/avante/utils/repo_map.lua @@ -0,0 +1,731 @@ +local parsers = require("nvim-treesitter.parsers") +local Config = require("avante.config") + +local get_node_text = vim.treesitter.get_node_text + +---@class avante.utils.repo_map +local RepoMap = {} + +local dependencies_queries = { + lua = [[ + (function_call + name: (identifier) @function_name + arguments: (arguments + (string) @required_file)) + ]], + + python = [[ + (import_from_statement + module_name: (dotted_name) @import_module) + (import_statement + (dotted_name) @import_module) + ]], + + javascript = [[ + (import_statement + source: (string) @import_module) + (call_expression + function: (identifier) @function_name + arguments: (arguments + (string) @required_file)) + ]], + + typescript = [[ + (import_statement + source: (string) @import_module) + (call_expression + function: (identifier) @function_name + arguments: (arguments + (string) @required_file)) + ]], + + go = [[ + (import_spec + path: (interpreted_string_literal) @import_module) + ]], + + rust = [[ + (use_declaration + (scoped_identifier) @import_module) + (use_declaration + (identifier) @import_module) + ]], + + c = [[ + (preproc_include + (string_literal) @import_module) + (preproc_include + (system_lib_string) @import_module) + ]], + + cpp = [[ + (preproc_include + (string_literal) @import_module) + (preproc_include + (system_lib_string) @import_module) + ]], +} + +local definitions_queries = { + python = [[ + ;; Capture top-level functions, class, and method definitions + (module + (expression_statement + (assignment) @assignment + ) + ) + (module + (function_definition) @function + ) + (module + (class_definition + body: (block + (expression_statement + (assignment) @class_assignment + ) + ) + ) + ) + (module + (class_definition + body: (block + (function_definition) @method + ) + ) + ) + ]], + javascript = [[ + ;; Capture exported functions, arrow functions, variables, classes, and method definitions + (export_statement + declaration: (lexical_declaration + (variable_declarator) @variable + ) + ) + (export_statement + declaration: (function_declaration) @function + ) + (export_statement + declaration: (class_declaration + body: (class_body + (field_definition) @class_variable + ) + ) + ) + (export_statement + declaration: (class_declaration + body: (class_body + (method_definition) @method + ) + ) + ) + ]], + typescript = [[ + ;; Capture exported functions, arrow functions, variables, classes, and method definitions + (export_statement + declaration: (lexical_declaration + (variable_declarator) @variable + ) + ) + (export_statement + declaration: (function_declaration) @function + ) + (export_statement + declaration: (class_declaration + body: (class_body + (public_field_definition) @class_variable + ) + ) + ) + (interface_declaration + body: (interface_body + (property_signature) @class_variable + ) + ) + (type_alias_declaration + value: (object_type + (property_signature) @class_variable + ) + ) + (export_statement + declaration: (class_declaration + body: (class_body + (method_definition) @method + ) + ) + ) + ]], + rust = [[ + ;; Capture public functions, structs, methods, and variable definitions + (function_item) @function + (impl_item + body: (declaration_list + (function_item) @method + ) + ) + (struct_item + body: (field_declaration_list + (field_declaration) @class_variable + ) + ) + (enum_item + body: (enum_variant_list + (enum_variant) @enum_item + ) + ) + (const_item) @variable + ]], + go = [[ + ;; Capture top-level functions and struct definitions + (var_declaration + (var_spec) @variable + ) + (const_declaration + (const_spec) @variable + ) + (function_declaration) @function + (type_declaration + (type_spec (struct_type)) @class + ) + (type_declaration + (type_spec + (struct_type + (field_declaration_list + (field_declaration) @class_variable))) + ) + (method_declaration) @method + ]], + c = [[ + ;; Capture extern functions, variables, public classes, and methods + (function_definition + (storage_class_specifier) @extern + ) @function + (class_specifier + (public) @class + (function_definition) @method + ) @class + (declaration + (storage_class_specifier) @extern + ) @variable + ]], + cpp = [[ + ;; Capture extern functions, variables, public classes, and methods + (function_definition + (storage_class_specifier) @extern + ) @function + (class_specifier + (public) @class + (function_definition) @method + ) @class + (declaration + (storage_class_specifier) @extern + ) @variable + ]], + lua = [[ + ;; Capture function and method definitions + (variable_list) @variable + (function_declaration) @function + ]], + ruby = [[ + ;; Capture top-level methods, class definitions, and methods within classes + (method) @function + (assignment) @assignment + (class + body: (body_statement + (assignment) @class_assignment + (method) @method + ) + ) + ]], +} + +local queries_filetype_map = { + ["javascriptreact"] = "javascript", + ["typescriptreact"] = "typescript", +} + +local function get_query(queries, filetype) + filetype = queries_filetype_map[filetype] or filetype + return queries[filetype] +end + +local function get_ts_lang(bufnr) + local lang = parsers.get_buf_lang(bufnr) + return lang +end + +function RepoMap.get_parser(bufnr) + local lang = get_ts_lang(bufnr) + if not lang then return end + local parser = parsers.get_parser(bufnr, lang) + return parser, lang +end + +function RepoMap.extract_dependencies(bufnr) + local parser, lang = RepoMap.get_parser(bufnr) + if not lang or not parser or not dependencies_queries[lang] then + print("No parser or query available for this buffer's language: " .. (lang or "unknown")) + return {} + end + + local dependencies = {} + local tree = parser:parse()[1] + local root = tree:root() + local filetype = vim.api.nvim_get_option_value("filetype", { buf = bufnr }) + + local query = get_query(dependencies_queries, filetype) + if not query then return dependencies end + + local query_obj = vim.treesitter.query.parse(lang, query) + + for _, node, _ in query_obj:iter_captures(root, bufnr, 0, -1) do + -- local name = query.captures[id] + local required_file = vim.treesitter.get_node_text(node, bufnr):gsub('"', ""):gsub("'", "") + table.insert(dependencies, required_file) + end + + return dependencies +end + +function RepoMap.get_filetype_by_filepath(filepath) return vim.filetype.match({ filename = filepath }) end + +function RepoMap.parse_file(filepath) + local File = require("avante.utils.file") + local source = File.read_content(filepath) + + local filetype = RepoMap.get_filetype_by_filepath(filepath) + local lang = parsers.ft_to_lang(filetype) + if lang then + local ok, parser = pcall(vim.treesitter.get_string_parser, source, lang) + if ok then + local tree = parser:parse()[1] + local node = tree:root() + return { node = node, source = source } + else + print("parser error", parser) + end + end +end + +local function get_closest_parent_name(node, source) + local parent = node:parent() + while parent do + local name = parent:field("name")[1] + if name then return get_node_text(name, source) end + parent = parent:parent() + end + return "" +end + +local function find_parent_by_type(node, type) + local parent = node:parent() + while parent do + if parent:type() == type then return parent end + parent = parent:parent() + end + return nil +end + +local function find_child_by_type(node, type) + for child in node:iter_children() do + if child:type() == type then return child end + local res = find_child_by_type(child, type) + if res then return res end + end + return nil +end + +local function get_node_type(node, source) + local node_type + local predefined_type_node = find_child_by_type(node, "predefined_type") + if predefined_type_node then + node_type = get_node_text(predefined_type_node, source) + else + local value_type_node = node:field("type")[1] + node_type = value_type_node and get_node_text(value_type_node, source) or "" + end + return node_type +end + +-- Function to extract definitions from the file +function RepoMap.extract_definitions(filepath) + local Utils = require("avante.utils") + + local filetype = RepoMap.get_filetype_by_filepath(filepath) + + if not filetype then return {} end + + -- Get the corresponding query for the detected language + local query = get_query(definitions_queries, filetype) + if not query then return {} end + + local parsed = RepoMap.parse_file(filepath) + if not parsed then return {} end + + -- Get the current buffer's syntax tree + local root = parsed.node + + local lang = parsers.ft_to_lang(filetype) + + -- Parse the query + local query_obj = vim.treesitter.query.parse(lang, query) + + -- Store captured results + local definitions = {} + + local class_def_map = {} + local enum_def_map = {} + + local function get_class_def(name) + local def = class_def_map[name] + if def == nil then + def = { + type = "class", + name = name, + methods = {}, + properties = {}, + } + class_def_map[name] = def + end + return def + end + + local function get_enum_def(name) + local def = enum_def_map[name] + if def == nil then + def = { + type = "enum", + name = name, + items = {}, + } + enum_def_map[name] = def + end + return def + end + + for _, captures, _ in query_obj:iter_matches(root, parsed.source) do + for id, node in pairs(captures) do + local type = query_obj.captures[id] + local name_node = node:field("name")[1] + local name = name_node and get_node_text(name_node, parsed.source) or "" + + if type == "class" then + if name ~= "" then get_class_def(name) end + elseif type == "enum_item" then + local enum_name = get_closest_parent_name(node, parsed.source) + if enum_name and filetype == "go" and not Utils.is_first_letter_uppercase(enum_name) then goto continue end + local enum_def = get_enum_def(enum_name) + local enum_type_node = find_child_by_type(node, "type_identifier") + local enum_type = enum_type_node and get_node_text(enum_type_node, parsed.source) or "" + table.insert(enum_def.items, { + name = name, + type = enum_type, + }) + elseif type == "method" then + local params_node = node:field("parameters")[1] + local params = params_node and get_node_text(params_node, parsed.source) or "()" + local return_type_node = node:field("return_type")[1] or node:field("result")[1] + local return_type = return_type_node and get_node_text(return_type_node, parsed.source) or "void" + + local class_name + local impl_item_node = find_parent_by_type(node, "impl_item") + local receiver_node = node:field("receiver")[1] + if impl_item_node then + local impl_type_node = impl_item_node:field("type")[1] + class_name = impl_type_node and get_node_text(impl_type_node, parsed.source) or "" + elseif receiver_node then + local type_identifier_node = find_child_by_type(receiver_node, "type_identifier") + class_name = type_identifier_node and get_node_text(type_identifier_node, parsed.source) or "" + else + class_name = get_closest_parent_name(node, parsed.source) + end + if class_name and filetype == "go" and not Utils.is_first_letter_uppercase(class_name) then goto continue end + local class_def = get_class_def(class_name) + + local accessibility_modifier_node = find_child_by_type(node, "accessibility_modifier") + local accessibility_modifier = accessibility_modifier_node + and get_node_text(accessibility_modifier_node, parsed.source) + or "" + + table.insert(class_def.methods, { + type = "function", + name = name, + params = params, + return_type = return_type, + accessibility_modifier = accessibility_modifier, + }) + elseif type == "class_assignment" then + local left_node = node:field("left")[1] + local left = left_node and get_node_text(left_node, parsed.source) or "" + + local value_type = get_node_type(node, parsed.source) + + local class_name = get_closest_parent_name(node, parsed.source) + if class_name and filetype == "go" and not Utils.is_first_letter_uppercase(class_name) then goto continue end + + local class_def = get_class_def(class_name) + + table.insert(class_def.properties, { + type = "variable", + name = left, + value_type = value_type, + }) + elseif type == "class_variable" then + local value_type = get_node_type(node, parsed.source) + + local class_name = get_closest_parent_name(node, parsed.source) + if class_name and filetype == "go" and not Utils.is_first_letter_uppercase(class_name) then goto continue end + + local class_def = get_class_def(class_name) + + table.insert(class_def.properties, { + type = "variable", + name = name, + value_type = value_type, + }) + elseif type == "function" or type == "arrow_function" then + if name and filetype == "go" and not Utils.is_first_letter_uppercase(name) then goto continue end + local impl_item_node = find_parent_by_type(node, "impl_item") + if impl_item_node then goto continue end + local function_node = find_parent_by_type(node, "function_declaration") + or find_parent_by_type(node, "function_definition") + if function_node then goto continue end + -- Extract function parameters and return type + local params_node = node:field("parameters")[1] + local params = params_node and get_node_text(params_node, parsed.source) or "()" + local return_type_node = node:field("return_type")[1] or node:field("result")[1] + local return_type = return_type_node and get_node_text(return_type_node, parsed.source) or "void" + + local accessibility_modifier_node = find_child_by_type(node, "accessibility_modifier") + local accessibility_modifier = accessibility_modifier_node + and get_node_text(accessibility_modifier_node, parsed.source) + or "" + + local def = { + type = "function", + name = name, + params = params, + return_type = return_type, + accessibility_modifier = accessibility_modifier, + } + table.insert(definitions, def) + elseif type == "assignment" then + local impl_item_node = find_parent_by_type(node, "impl_item") + or find_parent_by_type(node, "class_declaration") + or find_parent_by_type(node, "class_definition") + if impl_item_node then goto continue end + local function_node = find_parent_by_type(node, "function_declaration") + or find_parent_by_type(node, "function_definition") + if function_node then goto continue end + + local left_node = node:field("left")[1] + local left = left_node and get_node_text(left_node, parsed.source) or "" + + if left and filetype == "go" and not Utils.is_first_letter_uppercase(left) then goto continue end + + local value_type = get_node_type(node, parsed.source) + + local def = { + type = "variable", + name = left, + value_type = value_type, + } + table.insert(definitions, def) + elseif type == "variable" then + local impl_item_node = find_parent_by_type(node, "impl_item") + or find_parent_by_type(node, "class_declaration") + or find_parent_by_type(node, "class_definition") + if impl_item_node then goto continue end + local function_node = find_parent_by_type(node, "function_declaration") + or find_parent_by_type(node, "function_definition") + if function_node then goto continue end + + local value_type = get_node_type(node, parsed.source) + + if name and filetype == "go" and not Utils.is_first_letter_uppercase(name) then goto continue end + + local def = { type = "variable", name = name, value_type = value_type } + table.insert(definitions, def) + end + ::continue:: + end + end + + for _, def in pairs(class_def_map) do + table.insert(definitions, def) + end + + for _, def in pairs(enum_def_map) do + table.insert(definitions, def) + end + + return definitions +end + +local function stringify_function(def) + local res = "func " .. def.name .. def.params .. ": " .. def.return_type .. ";" + if def.accessibility_modifier and def.accessibility_modifier ~= "" then + res = def.accessibility_modifier .. " " .. res + end + return res +end + +local function stringify_variable(def) + local res = "var " .. def.name + if def.value_type and def.value_type ~= "" then res = res .. ": " .. def.value_type end + return res .. ";" +end + +local function stringify_enum_item(def) + local res = def.name + if def.value_type and def.value_type ~= "" then res = res .. ": " .. def.value_type end + return res .. ";" +end + +-- Function to load file content into a temporary buffer, process it, and then delete the buffer +function RepoMap.stringify_definitions(filepath) + if vim.endswith(filepath, "~") then return "" end + + -- Extract definitions + local definitions = RepoMap.extract_definitions(filepath) + + local output = "" + -- Print or process the definitions + for _, def in ipairs(definitions) do + if def.type == "class" then + output = output .. def.type .. " " .. def.name .. " {\n" + for _, property in ipairs(def.properties) do + output = output .. " " .. stringify_variable(property) .. "\n" + end + for _, method in ipairs(def.methods) do + output = output .. " " .. stringify_function(method) .. "\n" + end + output = output .. "}\n" + elseif def.type == "enum" then + output = output .. def.type .. " " .. def.name .. " {\n" + for _, item in ipairs(def.items) do + output = output .. " " .. stringify_enum_item(item) .. "\n" + end + output = output .. "}\n" + elseif def.type == "function" then + output = output .. stringify_function(def) .. "\n" + elseif def.type == "variable" then + output = output .. stringify_variable(def) .. "\n" + end + end + + return output +end + +function RepoMap._build_repo_map(project_root, file_ext) + local Utils = require("avante.utils") + local output = {} + local gitignore_path = project_root .. "/.gitignore" + local ignore_patterns, negate_patterns = Utils.parse_gitignore(gitignore_path) + local filepaths = Utils.scan_directory(project_root, ignore_patterns, negate_patterns) + vim.iter(filepaths):each(function(filepath) + if not Utils.is_same_file_ext(file_ext, filepath) then return end + local definitions = RepoMap.stringify_definitions(filepath) + if definitions == "" then return end + table.insert(output, { + path = Utils.relative_path(filepath), + lang = RepoMap.get_filetype_by_filepath(filepath), + defs = definitions, + }) + end) + return output +end + +local cache = {} + +function RepoMap.get_repo_map() + if not Config.behaviour.enable_project_context then return nil end + local Utils = require("avante.utils") + local project_root = Utils.root.get() + local file_ext = vim.fn.expand("%:e") + local cache_key = project_root .. "." .. file_ext + local cached = cache[cache_key] + if cached then return cached end + + local PPath = require("plenary.path") + local Path = require("avante.path") + local repo_map + + local function build_and_save() + repo_map = RepoMap._build_repo_map(project_root, file_ext) + cache[cache_key] = repo_map + Path.repo_map.save(project_root, file_ext, repo_map) + end + + repo_map = Path.repo_map.load(project_root, file_ext) + + if not repo_map or next(repo_map) == nil then + build_and_save() + if not repo_map then return end + else + local timer = vim.loop.new_timer() + + if timer then + timer:start( + 0, + 0, + vim.schedule_wrap(function() + build_and_save() + timer:close() + end) + ) + end + end + + local update_repo_map = vim.schedule_wrap(function(rel_filepath) + if rel_filepath and Utils.is_same_file_ext(file_ext, rel_filepath) then + local abs_filepath = PPath:new(project_root):joinpath(rel_filepath):absolute() + local definitions = RepoMap.stringify_definitions(abs_filepath) + if definitions == "" then return end + local found = false + for _, m in ipairs(repo_map) do + if m.path == rel_filepath then + m.defs = definitions + found = true + break + end + end + if not found then + table.insert(repo_map, { + path = Utils.relative_path(abs_filepath), + lang = RepoMap.get_filetype_by_filepath(abs_filepath), + defs = definitions, + }) + end + cache[cache_key] = repo_map + Path.repo_map.save(project_root, file_ext, repo_map) + end + end) + + local handle = vim.loop.new_fs_event() + + if handle then + handle:start(project_root, { recursive = true }, function(err, rel_filepath) + if err then + print("Error watching directory " .. project_root .. ":", err) + return + end + + if rel_filepath then update_repo_map(rel_filepath) end + end) + end + + vim.api.nvim_create_autocmd({ "BufReadPost", "BufNewFile" }, { + callback = function(ev) + vim.defer_fn(function() + local filepath = vim.api.nvim_buf_get_name(ev.buf) + if not vim.startswith(filepath, project_root) then return end + local rel_filepath = Utils.relative_path(filepath) + update_repo_map(rel_filepath) + end, 0) + end, + }) + + return repo_map +end + +return RepoMap