From 07630d1443d67094e96a52786a3b61a46079ed6b Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Sat, 13 Jul 2024 10:36:27 +0100 Subject: [PATCH] refactor: update async module --- lua/pckr/actions.lua | 63 +++++++++---------- lua/pckr/async.lua | 104 ++++++++++++++++++-------------- lua/pckr/display.lua | 5 +- lua/pckr/jobs.lua | 4 +- lua/pckr/lockfile.lua | 18 +++--- lua/pckr/plugin_types.lua | 6 +- lua/pckr/plugin_types/git.lua | 36 +++++------ lua/pckr/plugin_types/local.lua | 5 +- lua/pckr/status.lua | 5 +- 9 files changed, 131 insertions(+), 115 deletions(-) diff --git a/lua/pckr/actions.lua b/lua/pckr/actions.lua index 6d2a09a..ada47d7 100644 --- a/lua/pckr/actions.lua +++ b/lua/pckr/actions.lua @@ -133,12 +133,13 @@ local function update_helptags(results) end end +--- @async --- @param plugin Pckr.Plugin --- @param disp Pckr.Display --- @return string? -local post_update_hook = a.sync(function(plugin, disp) +local function post_update_hook(plugin, disp) if plugin.run or plugin.start then - a.main() + a.schedule() local loader = require('pckr.loader') loader.load_plugin(plugin) end @@ -147,7 +148,7 @@ local post_update_hook = a.sync(function(plugin, disp) return end - a.main() + a.schedule() local run_task = plugin.run @@ -176,14 +177,14 @@ local post_update_hook = a.sync(function(plugin, disp) end end end -end, 2) +end --- @alias Pckr.Task fun(plugin: Pckr.Plugin, disp: Pckr.Display, cb: fun()): string?, string? --- @param plugin Pckr.Plugin --- @param disp Pckr.Display --- @return string, string? -local install_task = a.sync(function(plugin, disp) +local install_task = a.sync(2, function(plugin, disp) disp:task_start(plugin.name, 'installing...') local plugin_type = require('pckr.plugin_types')[plugin.type] @@ -209,13 +210,13 @@ local install_task = a.sync(function(plugin, disp) end return plugin.name, err -end, 2) +end) --- @param plugin Pckr.Plugin --- @param disp Pckr.Display --- @param __cb? function --- @return string?, string? -local update_task = a.sync(function(plugin, disp, __cb) +local update_task = a.sync(2, function(plugin, disp, __cb) disp:task_start(plugin.name, 'updating...') if plugin.lock then @@ -263,10 +264,10 @@ local update_task = a.sync(function(plugin, disp, __cb) end return plugin.name, plugin.err -end, 2) +end) --- Find and remove any plugins not currently configured for use -local do_clean = a.sync(function() +local do_clean = a.sync(0, function() log.debug('Starting clean') local to_remove = fsstate.find_extra_plugins(pckr_plugins) @@ -278,7 +279,7 @@ local do_clean = a.sync(function() return end - a.main() + a.schedule() local lines = {} --- @type string[] for path, _ in pairs(to_remove) do @@ -299,7 +300,7 @@ local do_clean = a.sync(function() log.fmt_warn('Could not remove %s', path) end end -end, 0) +end) --- @return Pckr.Plugin local function get_pckr_spec() @@ -334,7 +335,7 @@ local function sync(op, plugins) local missing = fsstate.find_missing_plugins(pckr_plugins) local missing_plugins, installed_plugins = util.partition(missing, plugins) - a.main() + a.schedule() local disp = open_display() @@ -342,21 +343,21 @@ local function sync(op, plugins) if install then log.debug('Gathering install tasks') local results = map_task(install_task, missing_plugins, disp, 'installing') - a.main() + a.schedule() update_helptags(results) end if update then log.debug('Gathering update tasks') local results = map_task(update_task, installed_plugins, disp, 'updating') - a.main() + a.schedule() update_helptags(results) end if upgrade then pckr_plugins['pckr.nvim'] = get_pckr_spec() local results = map_task(update_task, { 'pckr.nvim' }, disp, 'updating') - a.main() + a.schedule() update_helptags(results) end end) @@ -369,9 +370,9 @@ end --- @param plugins? string[] --- @param _opts table? --- @param __cb? function -M.install = a.sync(function(plugins, _opts, __cb) +M.install = a.sync(2, function(plugins, _opts, __cb) sync('install', plugins) -end, 2) +end) --- Update operation: --- Takes an optional list of plugin names as an argument. If no list is given, @@ -380,9 +381,9 @@ end, 2) --- @param plugins? string[] List of plugin names to update. --- @param _opts table? --- @param __cb? function -M.update = a.sync(function(plugins, _opts, __cb) +M.update = a.sync(2, function(plugins, _opts, __cb) sync('update', plugins) -end, 2) +end) --- Sync operation: --- Takes an optional list of plugin names as an argument. If no list is given, @@ -391,43 +392,43 @@ end, 2) --- @param plugins? string[] --- @param _opts table? --- @param __cb? function -M.sync = a.sync(function(plugins, _opts, __cb) +M.sync = a.sync(2, function(plugins, _opts, __cb) sync('sync', plugins) -end, 2) +end) -M.upgrade = a.sync(function(_, _opts, __cb) +M.upgrade = a.sync(2, function(_, _opts, __cb) sync('upgrade') -end, 2) +end) --- @param _ any --- @param _opts table? --- @param __cb? function -M.status = a.sync(function(_, _opts, __cb) +M.status = a.sync(2, function(_, _opts, __cb) require('pckr.status').run() -end, 2) +end) --- Clean operation: --- Finds plugins present in the `pckr` package but not in the managed set --- @param _ any --- @param _opts table? --- @param __cb? function -M.clean = a.sync(function(_, _opts, __cb) +M.clean = a.sync(2, function(_, _opts, __cb) do_clean() -end, 2) +end) --- @param _ any --- @param _opts table? --- @param __cb? function -M.lock = a.sync(function(_, _opts, __cb) +M.lock = a.sync(2, function(_, _opts, __cb) require('pckr.lockfile').lock() -end, 2) +end) --- @param _ any --- @param _opts table? --- @param __cb? function -M.restore = a.sync(function(_, _opts, __cb) +M.restore = a.sync(2 ,function(_, _opts, __cb) require('pckr.lockfile').restore() -end, 2) +end) --- @param _ any --- @param _opts table? diff --git a/lua/pckr/async.lua b/lua/pckr/async.lua index 76e2bcc..1a12ece 100644 --- a/lua/pckr/async.lua +++ b/lua/pckr/async.lua @@ -1,5 +1,3 @@ -local co = coroutine - local function validate_callback(func, callback) if callback and type(callback) ~= 'function' then local info = debug.getinfo(func, 'nS') @@ -17,36 +15,38 @@ end --- @param func function --- @param callback function? --- @param ... any -local function execute(func, callback, ...) +local function run(func, callback, ...) validate_callback(func, callback) - local thread = co.create(func) + local co = coroutine.create(func) local function step(...) - local ret = { co.resume(thread, ...) } - --- @type boolean, integer, function - local stat, nargs, fn_or_ret = unpack(ret) + local ret = { coroutine.resume(co, ...) } + local stat = ret[1] if not stat then + local err = ret[2] --[[@as string]] error( - string.format( - 'The coroutine failed with this message: %s\n%s', - nargs, - debug.traceback(thread) - ) + string.format('The coroutine failed with this message: %s\n%s', err, debug.traceback(co)) ) end - if co.status(thread) == 'dead' then + if coroutine.status(co) == 'dead' then if callback then - callback(unpack(ret, 2)) + callback(unpack(ret, 2, table.maxn(ret))) end return end - local args = { select(4, unpack(ret)) } + --- @type integer, fun(...: any): any + local nargs, fn = ret[2], ret[3] + + assert(type(fn) == 'function', 'type error :: expected func') + + --- @type any[] + local args = { unpack(ret, 4, table.maxn(ret)) } args[nargs] = step - fn_or_ret(unpack(args, 1, nargs)) + fn(unpack(args, 1, nargs)) end step(...) @@ -54,17 +54,45 @@ end local M = {} +--- @param argc integer +--- @param func function +--- @param ... any +--- @return any ... +function M.wait(argc, func, ...) + -- Always run the wrapped functions in xpcall and re-raise the error in the + -- coroutine. This makes pcall work as normal. + local function pfunc(...) + local args = { ... } --- @type any[] + local cb = args[argc] + args[argc] = function(...) + cb(true, ...) + end + xpcall(func, function(err) + cb(false, err, debug.traceback()) + end, unpack(args, 1, argc)) + end + + local ret = { coroutine.yield(argc, pfunc, ...) } + + local ok = ret[1] + if not ok then + --- @type string, string + local err, traceback = ret[2], ret[3] + error(string.format('Wrapped function failed: %s\n%s', err, traceback)) + end + + return unpack(ret, 2, table.maxn(ret)) +end + --- Creates an async function with a callback style function. ---- @generic F: function ---- @param func F --- @param argc integer ---- @return F -function M.wrap(func, argc) +--- @param func function +--- @return function +function M.wrap(argc, func) + assert(type(argc) == 'number') + assert(type(func) == 'function') return function(...) - if not co.running() or select('#', ...) == argc then - return func(...) - end - return co.yield(argc, func, ...) + return M.wait(argc, func, ...) end end @@ -72,30 +100,14 @@ end ---called from a non-async context. Inherently this cannot return anything ---since it is non-blocking --- @generic F: function +--- @param nargs integer --- @param func async F ---- @param nargs? integer --- @return F -function M.sync(func, nargs) - nargs = nargs or 0 +function M.sync(nargs, func) return function(...) - if co.running() then - return func(...) - end + assert(not coroutine.running()) local callback = select(nargs + 1, ...) - execute(func, callback, unpack({ ... }, 1, nargs)) - end -end - ---- For functions that don't provide a callback as there last argument ---- @generic F: function ---- @param func F ---- @return F -function M.void(func) - return function(...) - if co.running() then - return func(...) - end - execute(func, nil, ...) + run(func, callback, unpack({ ... }, 1, nargs)) end end @@ -105,7 +117,7 @@ end --- @param thunks (fun(cb: function): R)[] --- @return {[1]: R}[] function M.join(n, interrupt_check, thunks) - return co.yield(1, function(finish) + return coroutine.yield(1, function(finish) if #thunks == 0 then return finish() end @@ -154,6 +166,6 @@ end ---An async function that when called will yield to the Neovim scheduler to be ---able to call the API. --- @type fun() -M.main = M.wrap(vim.schedule, 1) +M.schedule = M.wrap(1, vim.schedule) return M diff --git a/lua/pckr/display.lua b/lua/pckr/display.lua index f7f7879..6d0cd0e 100644 --- a/lua/pckr/display.lua +++ b/lua/pckr/display.lua @@ -61,7 +61,7 @@ local function get_plugin(disp) end --- @param inner? boolean ---- @return vim.api.keyset.float_config +--- @return vim.api.keyset.win_config local function get_win_config(inner) local vpad = inner and 8 or 6 local hpad = inner and 14 or 10 @@ -392,6 +392,7 @@ local function toggle_info(disp) api.nvim_win_set_cursor(disp.win, cursor_pos) end +--- @async --- Utility function to prompt a user with a question in a floating window --- @param headline string --- @param body string[] @@ -832,7 +833,7 @@ local M = {} --- Utility function to prompt a user with a question in a floating window --- @type fun(headline: string, body: string[]): boolean -M.ask_user = awrap(prompt_user, 3) +M.ask_user = awrap(3, prompt_user) --- Open a new display window --- @param cbs? Pckr.Display.Callbacks diff --git a/lua/pckr/jobs.lua b/lua/pckr/jobs.lua index 4b49394..f26079d 100644 --- a/lua/pckr/jobs.lua +++ b/lua/pckr/jobs.lua @@ -11,7 +11,7 @@ local M = {} --- @param opts vim.SystemOpts --- @param callback? fun(_: vim.SystemCompleted) --- @type fun(task: string|string[], opts: vim.SystemOpts): vim.SystemCompleted -M.run = a.wrap(function(task, opts, callback) +M.run = a.wrap(3, function(task, opts, callback) if type(task) == 'string' then local shell = os.getenv('SHELL') or vim.o.shell local minus_c = shell:find('cmd.exe$') and '/c' or '-c' @@ -25,6 +25,6 @@ M.run = a.wrap(function(task, opts, callback) callback(obj) end end) -end, 3) +end) return M diff --git a/lua/pckr/lockfile.lua b/lua/pckr/lockfile.lua index 9434023..982d541 100644 --- a/lua/pckr/lockfile.lua +++ b/lua/pckr/lockfile.lua @@ -60,10 +60,11 @@ local function update(path, info) f:close() end -M.lock = a.sync(function() +--- @async +function M.lock() local lock_tasks = {} --- @type fun()[] for _, plugin in pairs(P.plugins) do - lock_tasks[#lock_tasks + 1] = a.sync(function() + lock_tasks[#lock_tasks + 1] = a.sync(0, function() local plugin_type = plugin_types[plugin.type] if plugin_type.get_rev then return plugin.url, (plugin_type.get_rev(plugin)) @@ -79,16 +80,16 @@ M.lock = a.sync(function() end end - a.main() + a.schedule() local lockfile = config.lockfile.path update(lockfile, info1) log.fmt_info('Lockfile created at %s', config.lockfile.path) -end) +end --- @param plugin Pckr.Plugin --- @param disp Pckr.Display --- @param commit? string -local restore_plugin = a.sync(function(plugin, disp, commit) +local restore_plugin = a.sync(3, function(plugin, disp, commit) disp:task_start(plugin.name, fmt('restoring to %s', commit)) if plugin.type == 'local' then @@ -116,12 +117,13 @@ local restore_plugin = a.sync(function(plugin, disp, commit) end disp:task_succeeded(plugin.name, fmt('restored to commit %s', commit)) -end, 3) +end) --- @class LockInfo --- @field commit string -M.restore = a.sync(function() +--- @async +function M.restore() local disp = assert(display.open({})) disp:update_headline_message('Restoring from lockfile') @@ -136,6 +138,6 @@ M.restore = a.sync(function() end run_tasks(restore_tasks, disp, 'restoring') -end) +end return M diff --git a/lua/pckr/plugin_types.lua b/lua/pckr/plugin_types.lua index 3418586..a5cffcc 100644 --- a/lua/pckr/plugin_types.lua +++ b/lua/pckr/plugin_types.lua @@ -14,11 +14,7 @@ return setmetatable(M, { --- @param k string --- @return Pckr.PluginHandler? __index = function(t, k) - if k == 'git' then - t[k] = require('pckr.plugin_types.git') - elseif k == 'local' then - t[k] = require('pckr.plugin_types.local') - end + t[k] = require('pckr.plugin_types.'..k) return t[k] end, }) diff --git a/lua/pckr/plugin_types/git.lua b/lua/pckr/plugin_types/git.lua index 787c83f..d7a1d44 100644 --- a/lua/pckr/plugin_types/git.lua +++ b/lua/pckr/plugin_types/git.lua @@ -403,7 +403,7 @@ end local function sanitize_path(path) assert(path) --- @diagnostic disable-next-line - local lerr, stat = a.wrap(uv.fs_lstat, 2)(path) + local lerr, stat = a.wait(2, uv.fs_lstat, path) --- @diagnostic disable-next-line if lerr or stat.type ~= 'link' then -- path doesn't exist or isn't a link @@ -411,14 +411,14 @@ local function sanitize_path(path) end -- path is a link; check destination exists, otherwise delete - local err = a.wrap(uv.fs_realpath, 2)(path) + local err = a.wait(2, uv.fs_realpath, path) if not err then -- exists return end -- dead link; remove - a.wrap(uv.fs_unlink, 2)(path) + a.wait(2, uv.fs_unlink, path) end --- @param plugin Pckr.Plugin @@ -446,10 +446,11 @@ local function install(plugin, disp) return true, out end +--- @async --- @param plugin Pckr.Plugin --- @param disp Pckr.Display --- @return string? -M.installer = async(function(plugin, disp) +M.installer = function(plugin, disp) local ok, out = install(plugin, disp) if ok then @@ -460,7 +461,7 @@ M.installer = async(function(plugin, disp) plugin.err = out return out -end, 2) +end --- @param plugin Pckr.Plugin --- @param msg string @@ -548,22 +549,23 @@ local function update(plugin, disp, ff_only) return true, out end +--- @async --- @param plugin Pckr.Plugin --- @param disp Pckr.Display --- @param ff_only? boolean --- @return string? -M.updater = async(function(plugin, disp, ff_only) +M.updater = function(plugin, disp, ff_only) local ok, out = update(plugin, disp, ff_only) if not ok then plugin.err = out return out end plugin.messages = out -end, 3) +end --- @param plugin Pckr.Plugin --- @return string? -M.remote_url = async(function(plugin) +M.remote_url = async(1, function(plugin) local ok, out = git_run({ 'remote', 'get-url', 'origin' }, { cwd = plugin.install_path, }) @@ -571,12 +573,12 @@ M.remote_url = async(function(plugin) if ok then return out[1] end -end, 1) +end) --- @param plugin Pckr.Plugin --- @param commit string --- @param callback fun(_: string?, _: string?) -M.diff = async(function(plugin, commit, callback) +M.diff = async(3, function(plugin, commit, callback) local ok, out = git_run({ 'show', '--no-color', @@ -591,11 +593,11 @@ M.diff = async(function(plugin, commit, callback) else callback(nil, out) end -end, 3) +end) --- @param plugin Pckr.Plugin --- @return string? -M.revert_last = async(function(plugin) +M.revert_last = async(1, function(plugin) local ok, out = git_run({ 'reset', '--hard', 'HEAD@{1}' }, { cwd = plugin.install_path, }) @@ -612,13 +614,13 @@ M.revert_last = async(function(plugin) end log.fmt_info('Reverted update for %s', plugin.name) -end, 1) +end) --- Reset the plugin to `commit` --- @param plugin Pckr.Plugin --- @param commit string --- @return string? -M.revert_to = async(function(plugin, commit) +M.revert_to = async(2, function(plugin, commit) assert(type(commit) == 'string', fmt("commit: string expected but '%s' provided", type(commit))) log.fmt_debug("Reverting '%s' to commit '%s'", plugin.name, commit) local ok, out = git_run({ 'reset', '--hard', commit, '--' }, { @@ -628,13 +630,13 @@ M.revert_to = async(function(plugin, commit) if not ok then return out end -end, 2) +end) --- Returns HEAD's short hash --- @param plugin Pckr.Plugin --- @return string? -M.get_rev = async(function(plugin) +M.get_rev = async(1, function(plugin) return get_head(plugin.install_path) -end, 1) +end) return M diff --git a/lua/pckr/plugin_types/local.lua b/lua/pckr/plugin_types/local.lua index 9630e19..c3ba37c 100644 --- a/lua/pckr/plugin_types/local.lua +++ b/lua/pckr/plugin_types/local.lua @@ -13,16 +13,17 @@ M.installer = function(plugin, _disp) vim.loop.fs_symlink(plugin._dir, plugin.install_path) end +--- @async --- @param plugin Pckr.Plugin --- @param disp Pckr.Display --- @return string? -M.updater = a.sync(function(plugin, disp) +M.updater = function(plugin, disp) local gitdir = util.join_paths(plugin.install_path, '.git') if uv.fs_stat(gitdir) then return require('pckr.plugin_types.git').updater(plugin, disp, true) end -- Nothing to do -end) +end M.revert_to = function(_, _) log.warn("Can't revert a local plugin!") diff --git a/lua/pckr/status.lua b/lua/pckr/status.lua index c563b47..3b55f86 100644 --- a/lua/pckr/status.lua +++ b/lua/pckr/status.lua @@ -132,7 +132,8 @@ local function load_state(plugin) return '' end -M.run = a.sync(function() +--- @async +function M.run() local plugins = require('pckr.plugin').plugins if plugins == nil then log.warn('pckr_plugins table is nil! Cannot run pckr.status()!') @@ -163,6 +164,6 @@ M.run = a.sync(function() disp:update_headline_message( fmt('Total plugins: %d (%.2fms)', vim.tbl_count(plugins), total_time) ) -end) +end return M