diff --git a/lib/async_hooks.js b/lib/async_hooks.js index d5d5407b1de19b..db3ebec83d85d3 100644 --- a/lib/async_hooks.js +++ b/lib/async_hooks.js @@ -131,6 +131,18 @@ class AsyncHook { hook_fields[kAfter] -= +!!this[after_symbol]; hook_fields[kDestroy] -= +!!this[destroy_symbol]; hooks_array.splice(index, 1); + + if (hooks_array.length === 0) { + if (hook_fields[kInit] + hook_fields[kBefore] + + hook_fields[kAfter] + hook_fields[kDestroy] !== 0) { + const { inspect } = require('util'); + throw new Error(`Invalid async_hooks state: ${inspect(hook_fields)}`); + } + + async_wrap.tearDownHooks(); + setupHooksCalled = false; + } + return this; } } diff --git a/src/async-wrap.cc b/src/async-wrap.cc index 0ea6c64a8be582..68a653d46329c0 100644 --- a/src/async-wrap.cc +++ b/src/async-wrap.cc @@ -291,8 +291,7 @@ class PromiseWrap : public AsyncWrap { static void PromiseHook(PromiseHookType type, Local promise, Local parent, void* arg) { - Local context = promise->CreationContext(); - Environment* env = Environment::GetCurrent(context); + Environment* env = static_cast(arg); PromiseWrap* wrap = Unwrap(promise); if (type == PromiseHookType::kInit || wrap == nullptr) { bool silent = type != PromiseHookType::kInit; @@ -334,8 +333,24 @@ static void SetupHooks(const FunctionCallbackInfo& args) { SET_HOOK_FN(before); SET_HOOK_FN(after); SET_HOOK_FN(destroy); - env->AddPromiseHook(PromiseHook, nullptr); #undef SET_HOOK_FN + + env->AddPromiseHook(PromiseHook, static_cast(env)); +} + + +static void TearDownHooks(const FunctionCallbackInfo& args) { + Environment* env = Environment::GetCurrent(args); + + CHECK(!env->async_hooks_init_function().IsEmpty()); + + env->set_async_hooks_init_function(Local()); + env->set_async_hooks_before_function(Local()); + env->set_async_hooks_after_function(Local()); + env->set_async_hooks_destroy_function(Local()); + + bool removed = env->RemovePromiseHook(PromiseHook, static_cast(env)); + CHECK(removed); } @@ -391,6 +406,7 @@ void AsyncWrap::Initialize(Local target, HandleScope scope(isolate); env->SetMethod(target, "setupHooks", SetupHooks); + env->SetMethod(target, "tearDownHooks", TearDownHooks); env->SetMethod(target, "pushAsyncIds", PushAsyncIds); env->SetMethod(target, "popAsyncIds", PopAsyncIds); env->SetMethod(target, "clearIdStack", ClearIdStack); diff --git a/src/env.cc b/src/env.cc index 034625b375446c..8c90f67512a1f9 100644 --- a/src/env.cc +++ b/src/env.cc @@ -11,6 +11,7 @@ #endif #include +#include namespace node { @@ -195,6 +196,23 @@ void Environment::AddPromiseHook(promise_hook_func fn, void* arg) { } } +bool Environment::RemovePromiseHook(promise_hook_func fn, void* arg) { + auto it = std::find_if( + promise_hooks_.begin(), promise_hooks_.end(), + [&](const PromiseHookCallback& hook) { + return hook.cb_ == fn && hook.arg_ == arg; + }); + + if (it == promise_hooks_.end()) return false; + + promise_hooks_.erase(it); + if (promise_hooks_.empty()) { + isolate_->SetPromiseHook(nullptr); + } + + return true; +} + void Environment::EnvPromiseHook(v8::PromiseHookType type, v8::Local promise, v8::Local parent) { diff --git a/src/env.h b/src/env.h index c8c8232cc07fd4..cd7222b9d05365 100644 --- a/src/env.h +++ b/src/env.h @@ -653,6 +653,7 @@ class Environment { static const int kContextEmbedderDataIndex = NODE_CONTEXT_EMBEDDER_DATA_INDEX; void AddPromiseHook(promise_hook_func fn, void* arg); + bool RemovePromiseHook(promise_hook_func fn, void* arg); private: inline void ThrowError(v8::Local (*fun)(v8::Local),