diff --git a/src/async-wrap.cc b/src/async-wrap.cc index 4d357f2f41b7ef..7cfa9e0fe817aa 100644 --- a/src/async-wrap.cc +++ b/src/async-wrap.cc @@ -333,12 +333,7 @@ void PromiseWrap::GetParentId(Local property, static void PromiseHook(PromiseHookType type, Local promise, Local parent, void* arg) { - Local context = promise->CreationContext(); - Environment* env = Environment::GetCurrent(context); - - // PromiseHook() should never be called if no hooks have been enabled. - CHECK_GT(env->async_hooks()->fields()[AsyncHooks::kTotals], 0); - + Environment* env = static_cast(arg); Local resource_object_value = promise->GetInternalField(0); PromiseWrap* wrap = nullptr; if (resource_object_value->IsObject()) { @@ -376,9 +371,18 @@ static void PromiseHook(PromiseHookType type, Local promise, CHECK_NE(wrap, nullptr); if (type == PromiseHookType::kBefore) { + env->async_hooks()->push_ids(wrap->get_id(), wrap->get_trigger_id()); PreCallbackExecution(wrap, false); } else if (type == PromiseHookType::kAfter) { PostCallbackExecution(wrap, false); + if (env->current_async_id() == wrap->get_id()) { + // This condition might not be true if async_hooks was enabled during + // the promise callback execution. + // Popping it off the stack can be skipped in that case, because is is + // known that it would correspond to exactly one call with + // PromiseHookType::kBefore that was not witnessed by the PromiseHook. + env->async_hooks()->pop_ids(wrap->get_id()); + } } } @@ -429,13 +433,19 @@ static void SetupHooks(const FunctionCallbackInfo& args) { static void EnablePromiseHook(const FunctionCallbackInfo& args) { Environment* env = Environment::GetCurrent(args); - env->AddPromiseHook(PromiseHook, nullptr); + env->AddPromiseHook(PromiseHook, static_cast(env)); } static void DisablePromiseHook(const FunctionCallbackInfo& args) { Environment* env = Environment::GetCurrent(args); - env->RemovePromiseHook(PromiseHook, nullptr); + + // Delay the call to `RemovePromiseHook` because we might currently be + // between the `before` and `after` calls of a Promise. + env->isolate()->EnqueueMicrotask([](void* data) { + Environment* env = static_cast(data); + env->RemovePromiseHook(PromiseHook, data); + }, static_cast(env)); } diff --git a/src/env.cc b/src/env.cc index c97868a5882b41..0087f719dc00db 100644 --- a/src/env.cc +++ b/src/env.cc @@ -184,8 +184,11 @@ void Environment::AddPromiseHook(promise_hook_func fn, void* arg) { [&](const PromiseHookCallback& hook) { return hook.cb_ == fn && hook.arg_ == arg; }); - CHECK_EQ(it, promise_hooks_.end()); - promise_hooks_.push_back(PromiseHookCallback{fn, arg}); + if (it != promise_hooks_.end()) { + it->enable_count_++; + return; + } + promise_hooks_.push_back(PromiseHookCallback{fn, arg, 1}); if (promise_hooks_.size() == 1) { isolate_->SetPromiseHook(EnvPromiseHook); @@ -201,6 +204,8 @@ bool Environment::RemovePromiseHook(promise_hook_func fn, void* arg) { if (it == promise_hooks_.end()) return false; + if (--it->enable_count_ > 0) return true; + promise_hooks_.erase(it); if (promise_hooks_.empty()) { isolate_->SetPromiseHook(nullptr); diff --git a/src/env.h b/src/env.h index 458173ec6d2a41..527618d7c65409 100644 --- a/src/env.h +++ b/src/env.h @@ -709,6 +709,7 @@ class Environment { struct PromiseHookCallback { promise_hook_func cb_; void* arg_; + size_t enable_count_; }; std::vector promise_hooks_; diff --git a/test/addons/async-hooks-promise/test.js b/test/addons/async-hooks-promise/test.js index bbe11dd3c57d01..b0af8806bd665f 100644 --- a/test/addons/async-hooks-promise/test.js +++ b/test/addons/async-hooks-promise/test.js @@ -36,8 +36,12 @@ assert.strictEqual( hook1.disable(); -// Check that internal fields are no longer being set. -assert.strictEqual( - binding.getPromiseField(Promise.resolve(1)), - 0, - 'Promise internal field used despite missing enabled AsyncHook'); +// Check that internal fields are no longer being set. This needs to be delayed +// a bit because the `disable()` call only schedules disabling the hook in a +// future microtask. +setImmediate(() => { + assert.strictEqual( + binding.getPromiseField(Promise.resolve(1)), + 0, + 'Promise internal field used despite missing enabled AsyncHook'); +}); diff --git a/test/parallel/test-async-hooks-disable-during-promise.js b/test/parallel/test-async-hooks-disable-during-promise.js new file mode 100644 index 00000000000000..a81c4fbc40caf3 --- /dev/null +++ b/test/parallel/test-async-hooks-disable-during-promise.js @@ -0,0 +1,17 @@ +'use strict'; +const common = require('../common'); +const async_hooks = require('async_hooks'); + +const hook = async_hooks.createHook({ + init: common.mustCall(2), + before: common.mustCall(1), + after: common.mustNotCall() +}).enable(); + +Promise.resolve(1).then(common.mustCall(() => { + hook.disable(); + + Promise.resolve(42).then(common.mustCall()); + + process.nextTick(common.mustCall()); +})); diff --git a/test/parallel/test-async-hooks-enable-during-promise.js b/test/parallel/test-async-hooks-enable-during-promise.js new file mode 100644 index 00000000000000..106f433322ef94 --- /dev/null +++ b/test/parallel/test-async-hooks-enable-during-promise.js @@ -0,0 +1,13 @@ +'use strict'; +const common = require('../common'); +const async_hooks = require('async_hooks'); + +Promise.resolve(1).then(common.mustCall(() => { + async_hooks.createHook({ + init: common.mustCall(), + before: common.mustCall(), + after: common.mustCall(2) + }).enable(); + + process.nextTick(common.mustCall()); +})); diff --git a/test/parallel/test-async-hooks-promise-triggerid.js b/test/parallel/test-async-hooks-promise-triggerid.js new file mode 100644 index 00000000000000..7afd005855fb3e --- /dev/null +++ b/test/parallel/test-async-hooks-promise-triggerid.js @@ -0,0 +1,32 @@ +'use strict'; +const common = require('../common'); +const assert = require('assert'); +const async_hooks = require('async_hooks'); + +common.crashOnUnhandledRejection(); + +const promiseAsyncIds = []; + +async_hooks.createHook({ + init: common.mustCallAtLeast((id, type, triggerId) => { + if (type === 'PROMISE') { + // Check that the last known Promise is triggering the creation of + // this one. + assert.strictEqual(promiseAsyncIds[promiseAsyncIds.length - 1] || 1, + triggerId); + promiseAsyncIds.push(id); + } + }, 3), + before: common.mustCall((id) => { + assert.strictEqual(id, promiseAsyncIds[1]); + }), + after: common.mustCall((id) => { + assert.strictEqual(id, promiseAsyncIds[1]); + }) +}).enable(); + +Promise.resolve(42).then(common.mustCall(() => { + assert.strictEqual(async_hooks.executionAsyncId(), promiseAsyncIds[1]); + assert.strictEqual(async_hooks.triggerAsyncId(), promiseAsyncIds[0]); + Promise.resolve(10); +}));