diff --git a/after/plugin/gp.lua b/after/plugin/gp.lua new file mode 100644 index 0000000..535e629 --- /dev/null +++ b/after/plugin/gp.lua @@ -0,0 +1 @@ +require("gp.completion").register_cmd_source() diff --git a/data/ts_queries/lua.scm b/data/ts_queries/lua.scm new file mode 100644 index 0000000..a50326d --- /dev/null +++ b/data/ts_queries/lua.scm @@ -0,0 +1,45 @@ +;; Matches global and local function declarations +;; function a_fn_name() +;; +;; This will only match on top level functions. +;; Specificadlly, this ignores the local function declarations. +;; We're only doing this because we're requiring the +;; (file, function_name) pair to be unique in the database. +((chunk + (function_declaration + name: (identifier) @name) @body) + (#set! "type" "function")) + +;; Matches function declaration using the dot syntax +;; function a_table.a_fn_name() +((chunk + (function_declaration + name: (dot_index_expression) @name) @body) + (#set! "type" "function")) + +;; Matches function declaration using the member function syntax +;; function a_table:a_fn_name() +((chunk + (function_declaration + name: (method_index_expression) @name) @body) + (#set! "type" "function")) + +;; Matches on: +;; M.some_field = function() end +((chunk + (assignment_statement + (variable_list + name: (dot_index_expression) @name) + (expression_list + value: (function_definition) @body))) + (#set! "type" "function")) + +;; Matches on: +;; some_var = function() end +((chunk + (assignment_statement + (variable_list + name: (identifier) @name) + (expression_list + value: (function_definition) @body))) + (#set! "type" "function")) diff --git a/data/ts_queries/python.scm b/data/ts_queries/python.scm new file mode 100644 index 0000000..f7569d2 --- /dev/null +++ b/data/ts_queries/python.scm @@ -0,0 +1,20 @@ +;; Top level function definitions +((module + (function_definition + name: (identifier) @name ) @body + (#not-has-ancestor? @body class_definition)) + (#set! "type" "function")) + +;; Class member function definitions +((class_definition + name: (identifier) @classname + body: (block + (function_definition + name: (identifier) @name ) @body)) + (#set! "type" "class_method")) + + +;; Class definitions +((class_definition + name: (identifier) @name) @body + (#set! "type" "class")) diff --git a/lua/gp/completion.lua b/lua/gp/completion.lua new file mode 100644 index 0000000..12ab28a --- /dev/null +++ b/lua/gp/completion.lua @@ -0,0 +1,226 @@ +local u = require("gp.utils") +local context = require("gp.context") +local db = require("gp.db") +local cmp = require("cmp") + +---@class CompletionSource +---@field db Db +local source = {} + +source.src_name = "gp_completion" + +---@return CompletionSource +function source.new() + local db_inst = db.open() + return setmetatable({ db = db_inst }, { __index = source }) +end + +function source.get_trigger_characters() + return { "@", ":", "/" } +end + +-- Attaches the completion source to the given `bufnr` +function source.setup_for_chat_buffer(bufnr) + -- Don't attach the completion source if it's already been done + local attached_varname = "gp_source_attached" + if vim.b[attached_varname] then + return + end + + -- Attach the completion source + local config = require("cmp.config") + config.set_buffer({ + sources = { + { name = source.src_name }, + }, + }, bufnr) + + -- Set a flag so we don't try to set the source again + vim.b[attached_varname] = true +end + +function source.register_cmd_source() + cmp.register_source(source.src_name, source.new()) +end + +local function extract_cmd(request) + local target = request.context.cursor_before_line + local start = target:match(".*()@") + if start then + return string.sub(target, start, request.offset) + end +end + +local function completion_items_for_path(path) + -- The incoming path should either be + -- - A relative path that references a directory + -- - A relative path + partial filename as last component- + -- We need a bit of logic to figure out which directory content to return + + -------------------------------------------------------------------- + -- Figure out the full path of the directory we're trying to list -- + -------------------------------------------------------------------- + -- Split the path into component parts + local path_parts = u.path_split(path) + if path[#path] ~= "/" then + table.remove(path_parts) + end + + -- Assuming the cwd is the project root directory... + local cwd = vim.fn.getcwd() + local target_dir = u.path_join(cwd, unpack(path_parts)) + + -------------------------------------------- + -- List the items in the target directory -- + -------------------------------------------- + local handle = vim.loop.fs_scandir(target_dir) + local files = {} + + if not handle then + return files + end + + while true do + local name, type = vim.loop.fs_scandir_next(handle) + if not name then + break + end + + local item_name, item_kind + if type == "file" then + item_kind = cmp.lsp.CompletionItemKind.File + item_name = name + elseif type == "directory" then + item_kind = cmp.lsp.CompletionItemKind.Folder + item_name = name .. "/" + end + + table.insert(files, { + label = item_name, + kind = item_kind, + }) + end + + return files +end + +function source:completion_items_for_fn_name(partial_fn_name) + local result = self.db:find_symbol_by_name(partial_fn_name) + + local items = {} + if not result then + return items + end + + for _, row in ipairs(result) do + local item = { + -- fields meant for nvim-cmp + label = row.name, + labelDetails = { + detail = row.file, + }, + + -- fields meant for internal use + row = row, + type = "@code", + } + + if row.type == "class" then + item.kind = cmp.lsp.CompletionItemKind.Class + elseif row.type == "class_method" then + item.kind = cmp.lsp.CompletionItemKind.Method + else + item.kind = cmp.lsp.CompletionItemKind.Function + end + + table.insert(items, item) + end + + return items +end + +function source.complete(self, request, callback) + local input = string.sub(request.context.cursor_before_line, request.offset - 1) + local cmd = extract_cmd(request) + if not cmd then + return + end + + local cmd_parts = context.cmd_split(cmd) + + local items = {} + local isIncomplete = true + local cmd_type = cmd_parts[1] + + if cmd_type:match("@file") or cmd_type:match("@include") then + -- What's the path we're trying to provide completion for? + local path = cmd_parts[2] or "" + + -- List the items in the specified directory + items = completion_items_for_path(path) + + -- Say that the entire list has been provided + -- cmp won't call us again to provide an updated list + isIncomplete = false + elseif cmd_type:match("@code") then + local partial_fn_name = cmd_parts[2] or "" + + -- When the user confirms completion of an item, we alter the + -- command to look like `@code:path/to/file:fn_name` to uniquely + -- identify a function. + -- + -- If the user were to hit backspace to delete through the text, + -- don't process the input until it no longer looks like a path. + if partial_fn_name:match("/") then + return + end + + items = self:completion_items_for_fn_name(partial_fn_name) + isIncomplete = false + elseif input:match("^@") then + items = { + { label = "code", kind = require("cmp").lsp.CompletionItemKind.Keyword }, + { label = "file", kind = require("cmp").lsp.CompletionItemKind.Keyword }, + { label = "include", kind = require("cmp").lsp.CompletionItemKind.Keyword }, + } + isIncomplete = false + else + isIncomplete = false + end + + local data = { items = items, isIncomplete = isIncomplete } + callback(data) +end + +local function search_backwards(buf, pattern) + -- Use nvim_buf_call to execute a Vim command in the buffer context + return vim.api.nvim_buf_call(buf, function() + -- Search backwards for the pattern + local result = vim.fn.searchpos(pattern, "bn") + + if result[1] == 0 and result[2] == 0 then + return nil + end + return result + end) +end + +function source:execute(item, callback) + if item.type == "@code" then + -- Locate where @command starts and ends + local end_pos = vim.api.nvim_win_get_cursor(0) + local start_pos = search_backwards(0, "@code") + + -- Replace it with a custom piece of text and move the cursor to the end of the string + local text = string.format("@code:%s:%s", item.row.file, item.row.name) + vim.api.nvim_buf_set_text(0, start_pos[1] - 1, start_pos[2] - 1, end_pos[1] - 1, end_pos[2], { text }) + vim.api.nvim_win_set_cursor(0, { start_pos[1], start_pos[2] - 1 + #text }) + end + + -- After brief glance at the nvim-cmp source, it appears + -- we should call `callback` to continue the entry item selection + -- confirmation handling chain. + callback() +end + +return source diff --git a/lua/gp/context.lua b/lua/gp/context.lua new file mode 100644 index 0000000..53ea5cd --- /dev/null +++ b/lua/gp/context.lua @@ -0,0 +1,612 @@ +local u = require("gp.utils") +local gp = require("gp") +local logger = require("gp.logger") + +---@type Db +local Db = require("gp.db") + +local Context = {} + +-- Split a context insertion command into its component parts +-- This function will split the cmd by ":", at most into 3 parts. +-- It will grab the first 2 substrings that's split by ":", then +-- grab whatever is remaining as the 3rd string. +-- +-- Example: +-- cmd = "@code:/some/path/goes/here:class:fn_name" +-- => {"@code", "/some/path/goes/here", "class:fn_name"} +-- +-- This is can be used to split both @file and @code commands. +function Context.cmd_split(cmd) + local result = {} + local splits = u.string_find_all_substr(cmd, ":") + + local cursor = 0 + for i, split in ipairs(splits) do + if i > 2 then + break + end + local next_start = split[1] - 1 + local next_end = split[2] + table.insert(result, string.sub(cmd, cursor, next_start)) + cursor = next_end + 1 + end + + if cursor < #cmd then + table.insert(result, string.sub(cmd, cursor)) + end + + return result +end + +---@return string | nil +local function read_file(filepath) + local file = io.open(filepath, "r") + if not file then + return nil + end + local content = file:read("*all") + file:close() + return content +end + +local function file_exists(path) + local file = io.open(path, "r") + if file then + file:close() + return true + else + return false + end +end + +local function get_file_lines(filepath, start_line, end_line) + local lines = {} + local current_line = 0 + + -- Open the file for reading + local file = io.open(filepath, "r") + if not file then + logger.info("[get_file_lines] Could not open file: " .. filepath) + return nil + end + + for line in file:lines() do + if current_line >= start_line then + table.insert(lines, line) + end + if current_line > end_line then + break + end + current_line = current_line + 1 + end + + file:close() + + return lines +end + +-- Given a single message, parse out all the context insertion +-- commands, then return a new message with all the requested +-- context inserted +---@param msg string +function Context.insert_contexts(msg) + local context_texts = {} + + -- Parse out all context insertion commands + local cmds = {} + for cmd in msg:gmatch("@file:[%w%p]+") do + table.insert(cmds, cmd) + end + for cmd in msg:gmatch("@include:[%w%p]+") do + table.insert(cmds, cmd) + end + for cmd in msg:gmatch("@code:[%w%p]+[:%w_-]+") do + table.insert(cmds, cmd) + end + + local db = nil + + -- Process each command and turn it into a string be + -- inserted as additional context + for _, cmd in ipairs(cmds) do + local cmd_parts = Context.cmd_split(cmd) + local cmd_type = cmd_parts[1] + + if cmd_type == "@file" or cmd_type == "@include" then + -- Read the reqested file and produce a msg snippet to be joined later + local filepath = cmd_parts[2] + + local cwd = vim.fn.getcwd() + local fullpath = u.path_join(cwd, filepath) + + local content = read_file(fullpath) + if content then + local result + if cmd_type == "@file" then + result = string.format("%s\n```%s```", filepath, content) + else + result = content + end + table.insert(context_texts, result) + end + elseif cmd_type == "@code" then + local rel_path = cmd_parts[2] + local full_fn_name = cmd_parts[3] + if not rel_path or not full_fn_name then + goto continue + end + if db == nil then + db = Db.open() + end + + local fn_def = db:find_symbol_by_file_n_name(rel_path, full_fn_name) + if not fn_def then + logger.warning(string.format("Unable to locate function: '%s', '%s'", rel_path, full_fn_name)) + goto continue + end + + local fn_body = get_file_lines(fn_def.file, fn_def.start_line, fn_def.end_line) + if fn_body then + local result = string.format( + "In '%s', function '%s'\n```%s```", + fn_def.file, + fn_def.name, + table.concat(fn_body, "\n") + ) + table.insert(context_texts, result) + end + end + ::continue:: + end + + if db then + db:close() + end + + -- If no context insertions are requested, don't alter the original msg + if #context_texts == 0 then + return msg + else + -- Otherwise, build and return the final message + return string.format("%s\n\n%s", table.concat(context_texts, "\n"), msg) + end +end + +function Context.find_plugin_path(plugin_name) + local paths = vim.api.nvim_list_runtime_paths() + for _, path in ipairs(paths) do + local components = u.path_split(path) + if components[#components] == plugin_name then + return path + end + end +end + +-- Runs the supplied query on the supplied source file. +-- Returns all the captures as is. It is up to the caller to +-- know what the expected output is and to reshape the data. +---@param src_filepath string relative or full path to the src file to run the query on +---@param query_filepath string relative or full path to the query file to run +function Context.treesitter_query(src_filepath, query_filepath) + -- Read the source file content + ---WARNING: This is probably not a good idea for very large files + local src_content = read_file(src_filepath) + if not src_content then + logger.error("Unable to load src file: " .. src_filepath) + return nil + end + + -- Read the query file content + local query_content = read_file(query_filepath) + if not query_content then + logger.error("Unable to load query file: " .. query_filepath) + return nil + end + + -- Get the filetype of the source file + local filetype = vim.filetype.match({ filename = src_filepath }) + if not filetype then + logger.error("Unable to determine filetype for: " .. src_filepath) + return nil + end + + -- Check if the treesitter support for the language is available + local ok, err = pcall(vim.treesitter.language.add, filetype) + if not ok then + logger.error("TreeSitter parser for " .. filetype .. " is not installed") + logger.error(err) + return nil + end + + -- Parse the source text + -- local parser = vim.treesitter.get_parser(0, filetype) + local parser = vim.treesitter.get_string_parser(src_content, filetype, {}) + local tree = parser:parse()[1] + local root = tree:root() + + -- Create and run the query + local query = vim.treesitter.query.parse(filetype, query_content) + + -- Grab all the captures + local captures = {} + for id, node, metadata in query:iter_captures(root, src_content, 0, -1) do + local name = query.captures[id] + local start_row, start_col, end_row, end_col = node:range() + table.insert(captures, { + name = name, + node = node, + range = { start_row, start_col, end_row, end_col }, + text = vim.treesitter.get_node_text(node, src_content), + metadata = metadata, + }) + end + + return captures +end + +function Context.treesitter_extract_function_defs(src_filepath) + -- Make sure we can locate the source file + if not file_exists(src_filepath) then + logger.error("Unable to locate src file: " .. src_filepath) + return nil + end + + -- Get the filetype of the source file + local filetype = vim.filetype.match({ filename = src_filepath }) + if not filetype then + logger.error("Unable to determine filetype for: " .. src_filepath) + return nil + end + + -- We'll use the reported filetype as the name of the language + -- Try to locate a query file we can use to extract function definitions + local plugin_path = Context.find_plugin_path("gp.nvim") + if not plugin_path then + logger.error("Unable to locate path for gp.nvim...") + return nil + end + + -- Find the query file that's approprite for the language + local query_filepath = u.path_join(plugin_path, "data/ts_queries/" .. filetype .. ".scm") + if not file_exists(query_filepath) then + logger.debug("Unable to find function extraction ts query file: " .. query_filepath) + return nil + end + + -- Run the query + local captures = Context.treesitter_query(src_filepath, query_filepath) + if not captures then + return nil + end + + -- The captures are usually returned as a flat list with no way to tell + -- which captures came from the same symbol. But, if the query has attached + -- a some metadata to the query, all captured elements will reference the same metadata + -- table. We can then use this to correctly gather those elements into the same groups. + local function get_meta(x) + return x.metadata + end + captures = u.sort_by(get_meta, captures) + local groups = u.partition_by(get_meta, captures) + + -- Reshape the captures into a structure we'd like to work with + local results = {} + for _, group in ipairs(groups) do + local grp = {} + for _, item in ipairs(group) do + grp[item.name] = item + end + grp.metadata = group[1].metadata + + local type = grp.metadata.type + local item + if type == "function" then + item = { + file = src_filepath, + type = "function", + name = grp.name.text, + start_line = grp.body.range[1], + end_line = grp.body.range[3], + body = grp.body.text, -- for diagnostics + } + elseif type == "class_method" then + item = { + file = src_filepath, + type = "class_method", + name = string.format("%s.%s", grp.classname.text, grp.name.text), + start_line = grp.body.range[1], + end_line = grp.body.range[3], + body = grp.body.text, + } + elseif type == "class" then + item = { + file = src_filepath, + type = "class", + name = grp.name.text, + start_line = grp.body.range[1], + end_line = grp.body.range[3], + body = grp.body.text, + } + end + + item.body = nil -- Remove the diagnostics field to prep the entry for db insertion + table.insert(results, item) + end + + -- For debugging and manually checking the output + -- results = u.sort_by(function(x) + -- return x.start_line + -- end, results) + -- u.write_file("results.data.lua", vim.inspect(results)) + + return results +end + +---@param db Db +---@param src_filepath string +---@param generation? number +function Context.build_symbol_index_for_file(db, src_filepath, generation) + -- try to retrieve function definitions from the file + local symbols_list = Context.treesitter_extract_function_defs(src_filepath) + if not symbols_list then + return false + end + + -- Grab the src file meta data + local src_file_entry = db.collect_src_file_data(src_filepath) + if not src_file_entry then + logger.error("Unable to collect src file data for:" .. src_filepath) + return false + end + src_file_entry.last_scan_time = os.time() + src_file_entry.generation = generation + + -- Update the src file entry and the function definitions in a single transaction + local result = db:with_transaction(function() + local success = db:upsert_src_file(src_file_entry) + if not success then + logger.error("Upserting src_file failed") + return false + end + + success = db:upsert_and_clean_symbol_list_for_file(src_file_entry.filename, symbols_list) + if not success then + logger.error("Upserting symbol list failed") + return false + end + + return true + end) + return result +end + +local function make_gitignore_fn(git_root) + local base_paths = { git_root } + local allow = require("plenary.scandir").__make_gitignore(base_paths) + + return function(entry, rel_path, full_path, is_dir) + if entry == ".git" or entry == ".github" then + return false + end + if allow then + return allow(base_paths, full_path) + end + return true + end +end + +function Context.build_symbol_index(db) + local git_root = u.git_root_from_cwd() + if not git_root then + logger.error("[Context.build_symbol_index] Unable to locate project root") + return false + end + + local generation = u.random_8byte_int() + + u.walk_directory(git_root, { + should_process = make_gitignore_fn(git_root), + + process_file = function(rel_path, full_path) + if vim.filetype.match({ filename = full_path }) then + local success = Context.build_symbol_index_for_file(db, rel_path, generation) + if not success then + logger.debug("Failed to build function def index for: " .. rel_path) + end + end + end, + }) + + db.db:eval([[DELETE FROM src_files WHERE generation != ?]], { generation }) +end + +local ChangeResult = { + UNCHANGED = 0, + CHANGED = 1, + NOT_IN_LAST_SCAN = 2, +} + +-- Answers if the gien file seem to have changed since last scan +---@param db Db +---@param rel_path string +local function file_changed_since_last_scan(db, rel_path) + local cur = Db.collect_src_file_data(rel_path) + assert(cur) + + ---@type boolean|SrcFileEntry + local prev = db.db:eval([[SELECT * from src_files WHERE filename = ?]], { rel_path }) + if not prev then + return ChangeResult.NOT_IN_LAST_SCAN + end + + if cur.mod_time > prev.mod_time or cur.file_size ~= prev.file_size then + return ChangeResult.CHANGED + end + + return ChangeResult.UNCHANGED +end + +function Context.rebuild_symbol_index_for_changed_files(db) + local git_root = u.git_root_from_cwd() + if not git_root then + logger.error("[Context.build_symbol_index] Unable to locate project root") + return false + end + + local generation = u.random_8byte_int() + + u.walk_directory(git_root, { + should_process = make_gitignore_fn(git_root), + + process_file = function(rel_path, full_path) + if vim.filetype.match({ filename = full_path }) then + local status = file_changed_since_last_scan(db, rel_path) + if status == ChangeResult.UNCHANGED then + -- Even if the file did not change, we still want to mark the entry with the current generation ID + db.db:eval([[UPDATE src_files SET generation = ? WHERE filename = ?]], { generation, rel_path }) + return + end + local success = Context.build_symbol_index_for_file(db, rel_path, generation) + if not success then + logger.debug("Failed to build function def index for: " .. rel_path) + end + end + end, + }) + + db.db:eval([[DELETE FROM src_files WHERE generation != ?]], { generation }) +end + +function Context.index_single_file(src_filepath) + local db = Db.open() + if not db then + return + end + Context.build_symbol_index_for_file(db, src_filepath) + db:close() +end + +function Context.index_stale() + local uv = vim.uv or vim.loop + local start_time = uv.hrtime() + + local db = Db.open() + if not db then + return + end + Context.rebuild_symbol_index_for_changed_files(db) + db:close() + + local end_time = uv.hrtime() + local elapsed_time_ms = (end_time - start_time) / 1e6 + logger.info(string.format("Indexing took: %.2f ms", elapsed_time_ms)) +end + +function Context.index_all() + local uv = vim.uv or vim.loop + local start_time = uv.hrtime() + + local db = Db.open() + if not db then + return + end + Context.build_symbol_index(db) + db:close() + + local end_time = uv.hrtime() + local elapsed_time_ms = (end_time - start_time) / 1e6 + logger.info(string.format("Indexing took: %.2f ms", elapsed_time_ms)) +end + +function Context.build_initial_index() + local db = Db.open() + if not db then + return + end + + if db:get_metadata("done_initial_run") then + return + end + + Context.index_all() + db:set_metadata("done_initial_run", true) + db:close() +end + +function Context.setup_autocmd_update_index_periodically(bufnr) + local rebuild_time_var = "gp_next_rebuild_time" + local rebuild_period = 60 + u.buf_set_var(bufnr, rebuild_time_var, os.time() + rebuild_period) + + vim.api.nvim_create_autocmd("BufEnter", { + buffer = bufnr, + callback = function(arg) + local build_time = u.buf_get_var(arg.buf, rebuild_time_var) + if os.time() > build_time then + Context.index_stale() + u.buf_set_var(arg.buf, rebuild_time_var, os.time() + rebuild_period) + end + end, + }) +end + +-- Setup autocommand to update the function def index as the files are saved +function Context.setup_autocmd_update_index_on_file_save() + vim.api.nvim_create_autocmd("BufWritePost", { + pattern = { "*" }, + group = vim.api.nvim_create_augroup("GpFileIndexUpdate", { clear = true }), + callback = function(arg) + Context.index_single_file(arg.file) + end, + }) +end + +function Context.setup_for_chat_buffer(buf) + Context.build_initial_index() + Context.setup_autocmd_update_index_periodically(buf) + require("gp.completion").setup_for_chat_buffer(buf) +end + +-- Inserts the reference to the function under the cursor into the chat buffer +function Context.reference_current_function() + local db = Db.open() + if not db then + return + end + + local buf = vim.api.nvim_get_current_buf() + local rel_path = vim.fn.bufname(buf) + local lineno = math.max(vim.api.nvim_win_get_cursor(0)[1] - 1, 0) + + ---@type boolean|SymbolDefEntry + local res = db.db:eval( + [[ SELECT * from symbols + WHERE + file = ? AND + start_line <= ? AND + end_line >= ? ]], + { rel_path, lineno, lineno } + ) + + db:close() + + if type(res) == "boolean" then + logger.error("[context.reference_current_function] Symbol lookup returned unexpected value: " .. res) + return + end + + local entry = res[1] + + require("gp").chat_paste(string.format("@code:%s:%s", entry.file, entry.name)) +end + +function Context.reference_current_file() + local buf = vim.api.nvim_get_current_buf() + local rel_path = vim.fn.bufname(buf) + require("gp").chat_paste(string.format("@file:%s", rel_path)) +end + +Context.setup_autocmd_update_index_on_file_save() + +return Context diff --git a/lua/gp/db.lua b/lua/gp/db.lua new file mode 100644 index 0000000..77d6b97 --- /dev/null +++ b/lua/gp/db.lua @@ -0,0 +1,390 @@ +local sqlite = require("sqlite.db") +local sqlite_clib = require("sqlite.defs") +local gp = require("gp") +local u = require("gp.utils") +local logger = require("gp.logger") + +-- Describes files we've scanned previously to produce the list of symbols +---@class SrcFileEntry +---@field id number: unique id +---@field filename string: path relative to the git/project root +---@field file_size number: -- zie of the file at last scan in bytes +---@field filetype string: filetype as reported by neovim at last scan +---@field mod_time number: last file modification time reported by the os at last scan +---@field last_scan_time number: unix time stamp indicating when the last scan of this file was made +---@field generation? number: For internal use - garbage collection + +-- Describes where each of the functions are in the project +---@class SymbolDefEntry +---@field id number: unique id +---@field file string: Which file is the symbol defined? +---@field name string: Name of the symbol +---@field type string: type of the symbol +---@field start_line number: Which line in the file does the definition start? +---@field end_line number: Which line in the file does the definition end? +---@field generation? number: For internal use - garbage collection + +---@class Db +---@field db sqlite_db +local Db = {} + +--- @return Db +Db._new = function(db) + return setmetatable({ db = db }, { __index = Db }) +end + +--- Opens and/or creates a SQLite database for storing symbol definitions. +-- @return Db|nil A new Db object if successful, nil if an error occurs +-- @side-effect Creates .gp directory and database file if they don't exist +-- @side-effect Logs errors if unable to locate project root or create directory +function Db.open() + local git_root = u.git_root_from_cwd() + if git_root == "" then + logger.error("[db.open] Unable to locate project root") + return nil + end + + local db_file = u.path_join(git_root, ".gp/index.sqlite") + if not u.ensure_parent_path_exists(db_file) then + logger.error("[db.open] Unable create directory for db file: " .. db_file) + return nil + end + + ---@type sqlite_db + local db = sqlite({ + uri = db_file, + + -- The `metadata` table is a simple KV store + metadata = { + id = true, + key = { type = "text", required = true, unique = true }, + value = { type = "luatable", required = true }, + }, + + -- The `src_files` table stores a list of known src files and the last time they were scanned + src_files = { + id = true, + filename = { type = "text", required = true }, -- relative to the git/project root + file_size = { type = "integer", required = true }, -- size of the file at last scan + filetype = { type = "text", required = true }, -- filetype as reported by neovim at last scan + mod_time = { type = "integer", required = true }, -- file mod time reported by the fs at last scan + last_scan_time = { type = "integer", required = true }, -- unix timestamp + generation = { type = "integer" }, -- for garbage collection + }, + + symbols = { + id = true, + file = { type = "text", require = true, reference = "src_files.filename", on_delete = "cascade" }, + name = { type = "text", required = true }, + type = { type = "text", required = true }, + start_line = { type = "integer", required = true }, + end_line = { type = "integer", required = true }, + generation = { type = "integer" }, -- for garbage collection + }, + + opts = { keep_open = true }, + }) + + db:eval("CREATE UNIQUE INDEX IF NOT EXISTS idx_src_files_filename ON src_files (filename);") + db:eval("CREATE UNIQUE INDEX IF NOT EXISTS idx_symbol_file_n_name ON symbols (file, name);") + + return Db._new(db) +end + +--- Gathers information on a file to populate most of a SrcFileEntry. +--- @return SrcFileEntry|nil +function Db.collect_src_file_data(relative_path) + local uv = vim.uv or vim.loop + + -- Construct the full path to the file + local proj_root = u.git_root_from_cwd() + local fullpath = u.path_join(proj_root, relative_path) + + -- If the file doesn't exist, there is nothing to collect + local stat = uv.fs_stat(fullpath) + if not stat then + logger.debug("[Db.collection_src_file_data] failed: " .. relative_path) + return nil + end + + local entry = {} + + entry.filename = relative_path + entry.file_size = stat.size + entry.filetype = vim.filetype.match({ filename = fullpath }) + entry.mod_time = stat.mtime.sec + + return entry +end + +-- Upserts a single src file entry into the database +--- @param file SrcFileEntry +function Db:upsert_src_file(file) + if not self.db then + logger.error("[db.upsert_src_file] Database not initialized") + return false + end + + local sql = [[ + INSERT INTO src_files (filename, file_size, filetype, mod_time, last_scan_time, generation) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(filename) DO UPDATE SET + file_size = excluded.file_size, + filetype = excluded.filetype, + mod_time = excluded.mod_time, + last_scan_time = excluded.last_scan_time, + generation = excluded.generation + WHERE filename = ? + ]] + + local success = self.db:eval(sql, { + -- For the INSERT VALUES clause + file.filename, + file.file_size, + file.filetype, + file.mod_time, + file.last_scan_time, + file.generation or -1, + + -- For the WHERE claue + file.filename, + }) + + if not success then + logger.error("[db.upsert_src_file] Failed to upsert file: " .. file.filename) + return false + end + + return true +end + +--- @param filelist SrcFileEntry[] +function Db:upsert_filelist(filelist) + for _, file in ipairs(filelist) do + local success = self:upsert_src_file(file) + if not success then + logger.error("[db.upsert_filelist] Failed to upsert file list") + return false + end + end + + return true +end + +-- Upserts a single symbol entry into the database +--- @param def SymbolDefEntry +function Db:upsert_symbol(def) + if not self.db then + logger.error("[db.upsert_symbol] Database not initialized") + return false + end + + ---WARNING: Do not use ORM here. + -- This function can be called a lot during a full index rebuild. + -- Using the ORM here can cause a 100% slowdown. + local sql = [[ + INSERT INTO symbols (file, name, type, start_line, end_line, generation) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(file, name) DO UPDATE SET + type = excluded.type, + start_line = excluded.start_line, + end_line = excluded.end_line, + generation = excluded.generation + WHERE file = ? AND name = ? + ]] + + local success = self.db:eval(sql, { + -- For the INSERT VALUES clause + def.file, + def.name, + def.type, + def.start_line, + def.end_line, + def.generation or -1, + + -- For the WHERE clause + def.file, + def.name, + }) + + if not success then + logger.error("[db.upsert_symbol] Failed to upsert symbol: " .. def.name .. " for file: " .. def.file) + return false + end + + return true +end + +-- Wraps the given function in a sqlite transaction +---@param fn function() +function Db:with_transaction(fn) + local success, result + + success = self.db:execute("BEGIN") + if not success then + logger.error("[db.with_transaction] Unable to start transaction") + return false + end + + success, result = pcall(fn) + if not success then + logger.error("[db.with_transaction] fn return false") + logger.error(result) + + success = self.db:execute("ROLLBACK") + if not success then + logger.error("[db.with_transaction] Rollback failed") + end + return false + end + + success = self.db:execute("COMMIT") + if not success then + logger.error("[db.with_transaction] Unable to end transaction") + return false + end + + return true +end + +--- @param symbols_list SymbolDefEntry[] +function Db:upsert_and_clean_symbol_list_for_file(src_rel_path, symbols_list) + -- Generate a random generation ID for all tne newly updated/refreshed items + local generation = u.random_8byte_int() + for _, item in ipairs(symbols_list) do + item.generation = generation + end + + -- Upsert all entries + local success = self:upsert_symbol_list(symbols_list) + if not success then + return success + end + + -- Remove all symbols in the file that does not hav the new generation ID + -- Those symbols are not present in the newly generated list and should be removed. + success = self.db:eval([[DELETE from symbols WHERE file = ? and generation != ? ]], { src_rel_path, generation }) + if not success then + logger.error("[db.insert_and_clean_symbol_list_for_file] Unable to clean up garbage") + return success + end + + return true +end + +--- Updates the dastabase with the contents of the `symbols_list` +--- Note that this function early terminates if any of the entry upsert fails. +--- This behavior is only suitable when run inside a transaction. +--- @param symbols_list SymbolDefEntry[] +function Db:upsert_symbol_list(symbols_list) + for _, def in ipairs(symbols_list) do + local success = self:upsert_symbol(def) + if not success then + logger.error("[db.upsert_fnlist] Failed to upsert function def list") + return false + end + end + + return true +end + +function Db:close() + self.db:close() +end + +function Db:find_symbol_by_name(partial_fn_name) + local sql = [[ + SELECT * FROM symbols WHERE name LIKE ? + ]] + + local wildcard_name = "%" .. partial_fn_name .. "%" + + local result = self.db:eval(sql, { + wildcard_name, + }) + + -- We're expecting the query to return a list of SymbolDefEntry. + -- If we get a boolean back instead, we consider the operation to have failed. + if type(result) == "boolean" then + return nil + end + + ---@cast result SymbolDefEntry + return result +end + +function Db:find_symbol_by_file_n_name(rel_path, full_fn_name) + local sql = [[ + SELECT * FROM symbols WHERE file = ? AND name = ? + ]] + + local result = self.db:eval(sql, { + rel_path, + full_fn_name, + }) + + -- We're expecting the query to return a list of SymbolDefEntry. + -- If we get a boolean back instead, we consider the operation to have failed. + if type(result) == "boolean" then + return nil + end + + ---@cast result SymbolDefEntry[] + if #result > 1 then + logger.error( + string.format( + "[Db.find_symbol_by_file_n_name] Found more than 1 result for: '%s', '%s'", + rel_path, + full_fn_name + ) + ) + end + + return result[1] +end + +-- Removes a single entry from the src_files table given a relative file path +-- Note that related entries in the symbols table will be removed via CASCADE. +---@param src_filepath string +function Db:remove_src_file_entry(src_filepath) + local sql = [[ + DELETE FROM src_files WHERE filename = ? + ]] + + local result = self.db:eval(sql, { + src_filepath, + }) + + return result +end + +function Db:clear() + self.db:eval("DELETE FROM symbols") + self.db:eval("DELETE FROM src_files") +end + +-- Gets the value of a key from the metadata table +---@param keyname string +---@return any +function Db:get_metadata(keyname) + local result = self.db.metadata:where({ key = keyname }) + if result then + return result.value + end +end + +-- Sets the value of a key in the metadata table +-- WARNING: value cannot be of a number type +---@param keyname string +---@param value any +function Db:set_metadata(keyname, value) + -- The sqlite.lua plugin doesn't seem to like having numbers stored in the a field + -- marked as the "luatable" or "json" type. + -- If we store a number into the value field, sqlite.lua will throw a parse error on get. + if type(value) == "number" then + error("database metadata table doesn't not support storing a number as a root value") + end + return self.db.metadata:update({ where = { key = keyname }, set = { value = value } }) +end + +return Db diff --git a/lua/gp/init.lua b/lua/gp/init.lua index 4034e8e..f4d1e43 100644 --- a/lua/gp/init.lua +++ b/lua/gp/init.lua @@ -208,6 +208,18 @@ M.setup = function(opts) end end + vim.api.nvim_create_user_command("GpRebuildIndex", function(_) + require("gp.context").index_all() + end, {}) + + vim.api.nvim_create_user_command("GpReferenceCurrentFunction", function(_) + require("gp.context").reference_current_function() + end, {}) + + vim.api.nvim_create_user_command("GpReferenceCurrentFile", function(_) + require("gp.context").reference_current_file() + end, {}) + M.buf_handler() if vim.fn.executable("curl") == 0 then @@ -824,15 +836,20 @@ M.cmd.ChatNew = function(params, system_prompt, agent) end -- if chat toggle is open, close it and start a new one + local buf if M._toggle_close(M._toggle_kind.chat) then params.args = params.args or "" if params.args == "" then params.args = M.config.toggle_target end - return M.new_chat(params, true, system_prompt, agent) + buf = M.new_chat(params, true, system_prompt, agent) + else + buf = M.new_chat(params, false, system_prompt, agent) end - return M.new_chat(params, false, system_prompt, agent) + require("gp.context").setup_for_chat_buffer(buf) + + return buf end ---@param params table @@ -853,16 +870,90 @@ M.cmd.ChatToggle = function(params, system_prompt, agent) end -- if the range is 2, we want to create a new chat file with the selection + local buf if params.range ~= 2 then local last = M._state.last_chat if last and vim.fn.filereadable(last) == 1 then last = vim.fn.resolve(last) - M.open_buf(last, M.resolve_buf_target(params), M._toggle_kind.chat, true) - return + buf = M.open_buf(last, M.resolve_buf_target(params), M._toggle_kind.chat, true) + end + else + buf = M.new_chat(params, true, system_prompt, agent) + end + + if buf then + require("gp.context").setup_for_chat_buffer(buf) + end + + return buf +end + +local function win_for_buf(bufnr) + for _, w in ipairs(vim.api.nvim_list_wins()) do + if vim.api.nvim_win_get_buf(w) == bufnr then + return w + end + end +end + +local function create_buffer_with_file(file_path) + -- Create a new buffer + local bufnr = vim.api.nvim_create_buf(true, false) + + -- Set the buffer's name to the file path + vim.api.nvim_buf_set_name(bufnr, file_path) + + -- Load the file into the buffer + vim.api.nvim_buf_call(bufnr, function() + vim.api.nvim_command("edit " .. vim.fn.fnameescape(file_path)) + end) + + return bufnr +end + +-- Paste some content into the chat buffer +M.chat_paste = function(content) + -- locate the chat buffer + local chat_buf + local last = M._state.last_chat + + ------------------------------------------------ + -- Try to locate or setup a valid chat buffer -- + ------------------------------------------------ + -- If don't have a record of the last chat file that's been opened... + -- Just create a new chat + if not last or vim.fn.filereadable(last) ~= 1 then + chat_buf = M.cmd.ChatNew({}, nil, nil) + else + -- We have a record of the last chat file... + -- Can we locate a buffer with the file loaded? + last = vim.fn.resolve(last) + chat_buf = M.helpers.get_buffer(last) + + if not chat_buf then + chat_buf = create_buffer_with_file(last) end end - M.new_chat(params, true, system_prompt, agent) + -------------------------------------------- + -- Paste the content into the chat buffer -- + -------------------------------------------- + if chat_buf then + -- Paste the given `content` at the end of the buffer + vim.api.nvim_buf_set_lines(chat_buf, -1, -1, false, { content }) + + -- If we can locate a window for the buffer... + -- Set the cursor to the end of the file where we just pasted the content + local win = win_for_buf(chat_buf) + if win then + local line_count = vim.api.nvim_buf_line_count(chat_buf) + vim.api.nvim_win_set_cursor(win, { line_count, 0 }) + + vim.api.nvim_win_call(win, function() + vim.api.nvim_command("normal! zz") + end) + end + end end M.cmd.ChatPaste = function(params) @@ -1074,6 +1165,10 @@ M.chat_respond = function(params) local last_content_line = M.helpers.last_content_line(buf) vim.api.nvim_buf_set_lines(buf, last_content_line, last_content_line, false, { "", agent_prefix .. agent_suffix, "" }) + -- insert requested context in the message the user just entered + messages[#messages].content = require("gp.context").insert_contexts(messages[#messages].content) + -- print(vim.inspect(messages[#messages])) + -- call the model and write response M.dispatcher.query( buf, diff --git a/lua/gp/utils.lua b/lua/gp/utils.lua new file mode 100644 index 0000000..4cfdd2e --- /dev/null +++ b/lua/gp/utils.lua @@ -0,0 +1,232 @@ +local uv = vim.uv or vim.loop + +local Utils = {} + +function Utils.path_split(path) + return vim.split(path, "/") +end + +function Utils.path_join(...) + local args = { ... } + local parts = {} + + for i, part in ipairs(args) do + if type(part) ~= "string" then + error("Argument #" .. i .. " is not a string", 2) + end + + -- Remove leading/trailing separators (both / and \) + part = part:gsub("^[/\\]+", ""):gsub("[/\\]+$", "") + + if #part > 0 then + table.insert(parts, part) + end + end + + local result = table.concat(parts, "/") + + if args[1]:match("^[/\\]") then + result = "/" .. result + end + + return result +end + +function Utils.path_is_absolute(path) + if Utils.string_starts_with(path, "/") then + return true + end + return false +end + +function Utils.ensure_path_exists(path) + -- Check if the path exists + local stat = uv.fs_stat(path) + if stat and stat.type == "directory" then + -- The path exists and is a directory + return true + end + + -- Try to create the directory + return vim.fn.mkdir(path, "p") +end + +function Utils.ensure_parent_path_exists(path) + local components = Utils.path_split(path) + + -- Get the parent directory by removing the last component + table.remove(components) + local parent_path = table.concat(components, "/") + + return Utils.ensure_path_exists(parent_path) +end + +function Utils.string_starts_with(str, starting) + return string.sub(str, 1, string.len(starting)) == starting +end + +function Utils.string_ends_with(str, ending) + if #ending > #str then + return false + end + + return str:sub(-#ending) == ending +end + +---@class WalkDirectoryOptions +---@field should_process function Passed `entry`, `rel_path`, `full_path`, and `is_dir` +---@field process_file function +---@field on_error function +---@field recurse boolean +---@field max_depth number +--- +---@param dir string The directory to try to walk +---@param options WalkDirectoryOptions Describes how to walk the directory +--- +function Utils.walk_directory(dir, options) + options = options or {} + + local should_process = options.should_process or function() + return true + end + + local process_file = options.process_file or function(rel_path, full_path) + print(full_path) + end + local recurse = not options.recurse + + ---@type number + local max_depth = options.max_depth or math.huge + + local function walk(current_dir, current_depth) + if current_depth > max_depth then + return + end + + local entries = vim.fn.readdir(current_dir) + + for _, entry in ipairs(entries) do + local full_path = Utils.path_join(current_dir, entry) + local rel_path = full_path:sub(#dir + 2) + local is_dir = vim.fn.isdirectory(full_path) == 1 + + if should_process(entry, rel_path, full_path, is_dir) then + if is_dir then + if recurse then + walk(full_path, current_depth + 1) + end + else + pcall(process_file, rel_path, full_path) + end + end + end + end + + walk(dir, 1) +end + +--- Locates the git_root using the cwd +function Utils.git_root_from_cwd() + return require("gp.helper").find_git_root(vim.fn.getcwd()) +end + +-- If the given path is a relative path, turn it into a fullpath +-- based on the current git_root +---@param path string +function Utils.full_path_for_project_file(path) + if Utils.path_is_absolute(path) then + return path + end + + -- Construct the full path to the file + local proj_root = Utils.git_root_from_cwd() + return Utils.path_join(proj_root, path) +end + +function Utils.string_find_all_substr(str, substr) + local result = {} + local first = 0 + local last = 0 + + while true do + first, last = str:find(substr, first + 1) + if not first then + break + end + table.insert(result, { first, last }) + end + return result +end + +function Utils.partition_by(pred, list) + local result = {} + local current_partition = {} + local last_key = nil + + for _, item in ipairs(list) do + local key = pred(item) + if last_key == nil or key ~= last_key then + if #current_partition > 0 then + table.insert(result, current_partition) + end + current_partition = {} + end + table.insert(current_partition, item) + last_key = key + end + + if #current_partition > 0 then + table.insert(result, current_partition) + end + + return result +end + +function Utils.write_file(filename, content, mode) + mode = mode or "w" -- Default mode is write + if not content then + return true + end + local file = io.open(filename, mode) + if file then + file:write(content) + file:close() + else + error("Unable to open file: " .. filename) + end + return true +end + +function Utils.sort_by(key_fn, tbl) + table.sort(tbl, function(a, b) + local ka, kb = key_fn(a), key_fn(b) + if type(ka) == "table" and type(kb) == "table" then + -- Use table identifiers as tie-breaker + return tostring(ka) < tostring(kb) + else + return ka < kb + end + end) + return tbl +end + +function Utils.random_8byte_int() + return math.random(0, 0xFFFFFFFFFFFFFFFF) +end + +-- Gets a buffer variable or returns the default +function Utils.buf_get_var(buf, var_name, default) + local status, result = pcall(vim.api.nvim_buf_get_var, buf, var_name) + if status then + return result + else + return default + end +end + +-- This function is only here make the get/set call pair look consistent +function Utils.buf_set_var(buf, var_name, value) + return vim.api.nvim_buf_set_var(buf, var_name, value) +end + +return Utils