Skip to content

Commit

Permalink
refactor: update async module
Browse files Browse the repository at this point in the history
  • Loading branch information
lewis6991 committed Jul 13, 2024
1 parent 8f6ab24 commit 07630d1
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 115 deletions.
63 changes: 32 additions & 31 deletions lua/pckr/actions.lua
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,13 @@ local function update_helptags(results)
end
end

--- @async
--- @param plugin Pckr.Plugin
--- @param disp Pckr.Display
--- @return string?
local post_update_hook = a.sync(function(plugin, disp)
local function post_update_hook(plugin, disp)
if plugin.run or plugin.start then
a.main()
a.schedule()
local loader = require('pckr.loader')
loader.load_plugin(plugin)
end
Expand All @@ -147,7 +148,7 @@ local post_update_hook = a.sync(function(plugin, disp)
return
end

a.main()
a.schedule()

local run_task = plugin.run

Expand Down Expand Up @@ -176,14 +177,14 @@ local post_update_hook = a.sync(function(plugin, disp)
end
end
end
end, 2)
end

--- @alias Pckr.Task fun(plugin: Pckr.Plugin, disp: Pckr.Display, cb: fun()): string?, string?

--- @param plugin Pckr.Plugin
--- @param disp Pckr.Display
--- @return string, string?
local install_task = a.sync(function(plugin, disp)
local install_task = a.sync(2, function(plugin, disp)
disp:task_start(plugin.name, 'installing...')

local plugin_type = require('pckr.plugin_types')[plugin.type]
Expand All @@ -209,13 +210,13 @@ local install_task = a.sync(function(plugin, disp)
end

return plugin.name, err
end, 2)
end)

--- @param plugin Pckr.Plugin
--- @param disp Pckr.Display
--- @param __cb? function
--- @return string?, string?
local update_task = a.sync(function(plugin, disp, __cb)
local update_task = a.sync(2, function(plugin, disp, __cb)
disp:task_start(plugin.name, 'updating...')

if plugin.lock then
Expand Down Expand Up @@ -263,10 +264,10 @@ local update_task = a.sync(function(plugin, disp, __cb)
end

return plugin.name, plugin.err
end, 2)
end)

--- Find and remove any plugins not currently configured for use
local do_clean = a.sync(function()
local do_clean = a.sync(0, function()
log.debug('Starting clean')

local to_remove = fsstate.find_extra_plugins(pckr_plugins)
Expand All @@ -278,7 +279,7 @@ local do_clean = a.sync(function()
return
end

a.main()
a.schedule()

local lines = {} --- @type string[]
for path, _ in pairs(to_remove) do
Expand All @@ -299,7 +300,7 @@ local do_clean = a.sync(function()
log.fmt_warn('Could not remove %s', path)
end
end
end, 0)
end)

--- @return Pckr.Plugin
local function get_pckr_spec()
Expand Down Expand Up @@ -334,29 +335,29 @@ local function sync(op, plugins)
local missing = fsstate.find_missing_plugins(pckr_plugins)
local missing_plugins, installed_plugins = util.partition(missing, plugins)

a.main()
a.schedule()

local disp = open_display()

local delta = util.measure(function()
if install then
log.debug('Gathering install tasks')
local results = map_task(install_task, missing_plugins, disp, 'installing')
a.main()
a.schedule()
update_helptags(results)
end

if update then
log.debug('Gathering update tasks')
local results = map_task(update_task, installed_plugins, disp, 'updating')
a.main()
a.schedule()
update_helptags(results)
end

if upgrade then
pckr_plugins['pckr.nvim'] = get_pckr_spec()
local results = map_task(update_task, { 'pckr.nvim' }, disp, 'updating')
a.main()
a.schedule()
update_helptags(results)
end
end)
Expand All @@ -369,9 +370,9 @@ end
--- @param plugins? string[]
--- @param _opts table?
--- @param __cb? function
M.install = a.sync(function(plugins, _opts, __cb)
M.install = a.sync(2, function(plugins, _opts, __cb)
sync('install', plugins)
end, 2)
end)

--- Update operation:
--- Takes an optional list of plugin names as an argument. If no list is given,
Expand All @@ -380,9 +381,9 @@ end, 2)
--- @param plugins? string[] List of plugin names to update.
--- @param _opts table?
--- @param __cb? function
M.update = a.sync(function(plugins, _opts, __cb)
M.update = a.sync(2, function(plugins, _opts, __cb)
sync('update', plugins)
end, 2)
end)

--- Sync operation:
--- Takes an optional list of plugin names as an argument. If no list is given,
Expand All @@ -391,43 +392,43 @@ end, 2)
--- @param plugins? string[]
--- @param _opts table?
--- @param __cb? function
M.sync = a.sync(function(plugins, _opts, __cb)
M.sync = a.sync(2, function(plugins, _opts, __cb)
sync('sync', plugins)
end, 2)
end)

M.upgrade = a.sync(function(_, _opts, __cb)
M.upgrade = a.sync(2, function(_, _opts, __cb)
sync('upgrade')
end, 2)
end)

--- @param _ any
--- @param _opts table?
--- @param __cb? function
M.status = a.sync(function(_, _opts, __cb)
M.status = a.sync(2, function(_, _opts, __cb)
require('pckr.status').run()
end, 2)
end)

--- Clean operation:
--- Finds plugins present in the `pckr` package but not in the managed set
--- @param _ any
--- @param _opts table?
--- @param __cb? function
M.clean = a.sync(function(_, _opts, __cb)
M.clean = a.sync(2, function(_, _opts, __cb)
do_clean()
end, 2)
end)

--- @param _ any
--- @param _opts table?
--- @param __cb? function
M.lock = a.sync(function(_, _opts, __cb)
M.lock = a.sync(2, function(_, _opts, __cb)
require('pckr.lockfile').lock()
end, 2)
end)

--- @param _ any
--- @param _opts table?
--- @param __cb? function
M.restore = a.sync(function(_, _opts, __cb)
M.restore = a.sync(2 ,function(_, _opts, __cb)
require('pckr.lockfile').restore()
end, 2)
end)

--- @param _ any
--- @param _opts table?
Expand Down
104 changes: 58 additions & 46 deletions lua/pckr/async.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
local co = coroutine

local function validate_callback(func, callback)
if callback and type(callback) ~= 'function' then
local info = debug.getinfo(func, 'nS')
Expand All @@ -17,85 +15,99 @@ end
--- @param func function
--- @param callback function?
--- @param ... any
local function execute(func, callback, ...)
local function run(func, callback, ...)
validate_callback(func, callback)

local thread = co.create(func)
local co = coroutine.create(func)

local function step(...)
local ret = { co.resume(thread, ...) }
--- @type boolean, integer, function
local stat, nargs, fn_or_ret = unpack(ret)
local ret = { coroutine.resume(co, ...) }
local stat = ret[1]

if not stat then
local err = ret[2] --[[@as string]]
error(
string.format(
'The coroutine failed with this message: %s\n%s',
nargs,
debug.traceback(thread)
)
string.format('The coroutine failed with this message: %s\n%s', err, debug.traceback(co))
)
end

if co.status(thread) == 'dead' then
if coroutine.status(co) == 'dead' then
if callback then
callback(unpack(ret, 2))
callback(unpack(ret, 2, table.maxn(ret)))
end
return
end

local args = { select(4, unpack(ret)) }
--- @type integer, fun(...: any): any
local nargs, fn = ret[2], ret[3]

assert(type(fn) == 'function', 'type error :: expected func')

--- @type any[]
local args = { unpack(ret, 4, table.maxn(ret)) }
args[nargs] = step
fn_or_ret(unpack(args, 1, nargs))
fn(unpack(args, 1, nargs))
end

step(...)
end

local M = {}

--- @param argc integer
--- @param func function
--- @param ... any
--- @return any ...
function M.wait(argc, func, ...)
-- Always run the wrapped functions in xpcall and re-raise the error in the
-- coroutine. This makes pcall work as normal.
local function pfunc(...)
local args = { ... } --- @type any[]
local cb = args[argc]
args[argc] = function(...)
cb(true, ...)
end
xpcall(func, function(err)
cb(false, err, debug.traceback())
end, unpack(args, 1, argc))
end

local ret = { coroutine.yield(argc, pfunc, ...) }

local ok = ret[1]
if not ok then
--- @type string, string
local err, traceback = ret[2], ret[3]
error(string.format('Wrapped function failed: %s\n%s', err, traceback))
end

return unpack(ret, 2, table.maxn(ret))
end

--- Creates an async function with a callback style function.
--- @generic F: function
--- @param func F
--- @param argc integer
--- @return F
function M.wrap(func, argc)
--- @param func function
--- @return function
function M.wrap(argc, func)
assert(type(argc) == 'number')
assert(type(func) == 'function')
return function(...)
if not co.running() or select('#', ...) == argc then
return func(...)
end
return co.yield(argc, func, ...)
return M.wait(argc, func, ...)
end
end

---Use this to create a function which executes in an async context but
---called from a non-async context. Inherently this cannot return anything
---since it is non-blocking
--- @generic F: function
--- @param nargs integer
--- @param func async F
--- @param nargs? integer
--- @return F
function M.sync(func, nargs)
nargs = nargs or 0
function M.sync(nargs, func)
return function(...)
if co.running() then
return func(...)
end
assert(not coroutine.running())
local callback = select(nargs + 1, ...)
execute(func, callback, unpack({ ... }, 1, nargs))
end
end

--- For functions that don't provide a callback as there last argument
--- @generic F: function
--- @param func F
--- @return F
function M.void(func)
return function(...)
if co.running() then
return func(...)
end
execute(func, nil, ...)
run(func, callback, unpack({ ... }, 1, nargs))
end
end

Expand All @@ -105,7 +117,7 @@ end
--- @param thunks (fun(cb: function): R)[]
--- @return {[1]: R}[]
function M.join(n, interrupt_check, thunks)
return co.yield(1, function(finish)
return coroutine.yield(1, function(finish)
if #thunks == 0 then
return finish()
end
Expand Down Expand Up @@ -154,6 +166,6 @@ end
---An async function that when called will yield to the Neovim scheduler to be
---able to call the API.
--- @type fun()
M.main = M.wrap(vim.schedule, 1)
M.schedule = M.wrap(1, vim.schedule)

return M
5 changes: 3 additions & 2 deletions lua/pckr/display.lua
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ local function get_plugin(disp)
end

--- @param inner? boolean
--- @return vim.api.keyset.float_config
--- @return vim.api.keyset.win_config
local function get_win_config(inner)
local vpad = inner and 8 or 6
local hpad = inner and 14 or 10
Expand Down Expand Up @@ -392,6 +392,7 @@ local function toggle_info(disp)
api.nvim_win_set_cursor(disp.win, cursor_pos)
end

--- @async
--- Utility function to prompt a user with a question in a floating window
--- @param headline string
--- @param body string[]
Expand Down Expand Up @@ -832,7 +833,7 @@ local M = {}

--- Utility function to prompt a user with a question in a floating window
--- @type fun(headline: string, body: string[]): boolean
M.ask_user = awrap(prompt_user, 3)
M.ask_user = awrap(3, prompt_user)

--- Open a new display window
--- @param cbs? Pckr.Display.Callbacks
Expand Down
Loading

0 comments on commit 07630d1

Please sign in to comment.