Skip to content

Commit

Permalink
Replace RAIICallbackWrapperDestroyer with AsyncCallback (#39952)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #39952

AsyncCallback and SyncCallbacks are better primitives for jsi::Function handling. The code is simpler and requires less manual argument passing. See in D49684248 how the API was extended to support more use-cases.

Changelog: [General] Deprecated RAIICallbackWrapperDestroyer. Use AsyncCallback instead for safe jsi::Function memory ownership.

Differential Revision: D49792717

fbshipit-source-id: ffe9e3de79d3f60d064e758e1bd2eaa0e9aa5547
  • Loading branch information
javache authored and facebook-github-bot committed Oct 9, 2023
1 parent 40f15d4 commit e517072
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 150 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ jsi::Value createPromiseAsJSIValue(
jsi::Runtime& rt,
PromiseSetupFunctionType&& func);

// Deprecated. Use AsyncCallback instead.
class RAIICallbackWrapperDestroyer {
public:
RAIICallbackWrapperDestroyer(std::weak_ptr<CallbackWrapper> callbackWrapper)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <ReactCommon/TurboModulePerfLogger.h>
#include <ReactCommon/TurboModuleUtils.h>
#include <jsi/JSIDynamic.h>
#include <react/bridging/Bridging.h>
#include <react/debug/react_native_assert.h>
#include <react/jni/NativeMap.h>
#include <react/jni/ReadableNativeMap.h>
Expand Down Expand Up @@ -62,61 +63,43 @@ struct JNIArgs {
std::vector<jobject> globalRefs_;
};

jni::local_ref<JCxxCallbackImpl::JavaPart> createJavaCallbackFromJSIFunction(
jsi::Function&& function,
auto createJavaCallback(
jsi::Runtime& rt,
const std::shared_ptr<CallInvoker>& jsInvoker) {
auto weakWrapper =
CallbackWrapper::createWeak(std::move(function), rt, jsInvoker);

// This needs to be a shared_ptr because:
// 1. It cannot be unique_ptr. std::function is copyable but unique_ptr is
// not.
// 2. It cannot be weak_ptr since we need this object to live on.
// 3. It cannot be a value, because that would be deleted as soon as this
// function returns.
auto callbackWrapperOwner =
std::make_shared<RAIICallbackWrapperDestroyer>(weakWrapper);

jsi::Function&& function,
std::shared_ptr<CallInvoker> jsInvoker) {
std::optional<AsyncCallback<>> callback(
{rt, std::move(function), std::move(jsInvoker)});
return JCxxCallbackImpl::newObjectCxxArgs(
[weakWrapper = std::move(weakWrapper),
callbackWrapperOwner = std::move(callbackWrapperOwner),
wrapperWasCalled = false](folly::dynamic responses) mutable {
if (wrapperWasCalled) {
LOG(FATAL) << "callback arg cannot be called more than once";
}

auto strongWrapper = weakWrapper.lock();
if (!strongWrapper) {
[callback = std::move(callback)](folly::dynamic args) mutable {
if (!callback) {
LOG(FATAL) << "Callback arg cannot be called more than once";
return;
}

strongWrapper->jsInvoker().invokeAsync(
[weakWrapper = std::move(weakWrapper),
callbackWrapperOwner = std::move(callbackWrapperOwner),
responses = std::move(responses)]() {
auto strongWrapper2 = weakWrapper.lock();
if (!strongWrapper2) {
return;
}

std::vector<jsi::Value> args;
args.reserve(responses.size());
for (const auto& val : responses) {
args.emplace_back(
jsi::valueFromDynamic(strongWrapper2->runtime(), val));
}

strongWrapper2->callback().call(
strongWrapper2->runtime(),
(const jsi::Value*)args.data(),
args.size());
});

wrapperWasCalled = true;
callback->call([args = std::move(args)](
jsi::Runtime& rt, jsi::Function& jsFunction) {
std::vector<jsi::Value> jsArgs;
jsArgs.reserve(args.size());
for (const auto& val : args) {
jsArgs.emplace_back(jsi::valueFromDynamic(rt, val));
}
jsFunction.call(rt, (const jsi::Value*)jsArgs.data(), jsArgs.size());
});
callback = std::nullopt;
});
}

struct JPromiseImpl : public jni::JavaClass<JPromiseImpl> {
constexpr static auto kJavaDescriptor =
"Lcom/facebook/react/bridge/PromiseImpl;";

static jni::local_ref<javaobject> create(
jni::local_ref<JCallback::javaobject> resolve,
jni::local_ref<JCallback::javaobject> reject) {
return newInstance(resolve, reject);
}
};

// This is used for generating short exception strings.
std::string stringifyJSIValue(const jsi::Value& v, jsi::Runtime* rt = nullptr) {
if (v.isUndefined()) {
Expand Down Expand Up @@ -355,8 +338,7 @@ JNIArgs convertJSIArgsToJNIArgs(
}
jsi::Function fn = arg->getObject(rt).getFunction(rt);
jarg->l = makeGlobalIfNecessary(
createJavaCallbackFromJSIFunction(std::move(fn), rt, jsInvoker)
.release());
createJavaCallback(rt, std::move(fn), jsInvoker).release());
} else if (type == "Lcom/facebook/react/bridge/ReadableArray;") {
if (!(arg->isObject() && arg->getObject(rt).isArray(rt))) {
throw JavaTurboModuleArgumentConversionException(
Expand Down Expand Up @@ -744,16 +726,14 @@ jsi::Value JavaTurboModule::invokeJavaMethod(
instance_ = jni::make_weak(instance_),
moduleNameStr = name_,
methodNameStr,
id = getUniqueId()]() mutable -> void {
id = getUniqueId()]() mutable {
auto instance = instance_.lockLocal();
if (!instance) {
return;
}
/**
* TODO(ramanpreet): Why do we have to require the environment
* again? Why does JNI crash when we use the env from the upper
* scope?
*/

// Require the env from the current scope, which may be
// different from the original invocation's scope
JNIEnv* env = jni::Environment::current();
const char* moduleName = moduleNameStr.c_str();
const char* methodName = methodNameStr.c_str();
Expand All @@ -777,115 +757,85 @@ jsi::Value JavaTurboModule::invokeJavaMethod(
return jsi::Value::undefined();
}
case PromiseKind: {
// We could use AsyncPromise here, but this avoids the overhead of
// the shared_ptr for PromiseHolder
jsi::Function Promise =
runtime.global().getPropertyAsFunction(runtime, "Promise");

jsi::Function promiseConstructorArg = jsi::Function::createFromHostFunction(
// The promise constructor runs its arg immediately, so this is safe
jobject javaPromise;
jsi::Value jsPromise = Promise.callAsConstructor(
runtime,
jsi::PropNameID::forAscii(runtime, "fn"),
2,
[this,
&jargs,
&globalRefs,
argCount,
jsi::Function::createFromHostFunction(
runtime,
jsi::PropNameID::forAscii(runtime, "fn"),
2,
[&](jsi::Runtime& runtime,
const jsi::Value&,
const jsi::Value* args,
size_t argCount) {
if (argCount != 2) {
throw jsi::JSError(runtime, "Incorrect number of arguments");
}

auto resolve = createJavaCallback(
runtime,
args[0].getObject(runtime).getFunction(runtime),
jsInvoker_);
auto reject = createJavaCallback(
runtime,
args[1].getObject(runtime).getFunction(runtime),
jsInvoker_);
javaPromise = JPromiseImpl::create(resolve, reject).release();

return jsi::Value::undefined();
}));

jobject globalPromise = env->NewGlobalRef(javaPromise);
globalRefs.push_back(globalPromise);
env->DeleteLocalRef(javaPromise);
jargs[argCount].l = globalPromise;

const char* moduleName = name_.c_str();
const char* methodName = methodNameStr.c_str();
TMPL::asyncMethodCallArgConversionEnd(moduleName, methodName);

TMPL::asyncMethodCallDispatch(moduleName, methodName);
nativeMethodCallInvoker_->invokeAsync(
methodName,
[jargs,
globalRefs,
methodID,
instance_ = jni::make_weak(instance_),
moduleNameStr = name_,
methodNameStr,
env](
jsi::Runtime& runtime,
const jsi::Value& thisVal,
const jsi::Value* promiseConstructorArgs,
size_t promiseConstructorArgCount) {
if (promiseConstructorArgCount != 2) {
throw std::invalid_argument("Promise fn arg count must be 2");
id = getUniqueId()]() mutable {
auto instance = instance_.lockLocal();
if (!instance) {
return;
}

jsi::Function resolveJSIFn =
promiseConstructorArgs[0].getObject(runtime).getFunction(
runtime);
jsi::Function rejectJSIFn =
promiseConstructorArgs[1].getObject(runtime).getFunction(
runtime);

auto resolve = createJavaCallbackFromJSIFunction(
std::move(resolveJSIFn), runtime, jsInvoker_)
.release();
auto reject = createJavaCallbackFromJSIFunction(
std::move(rejectJSIFn), runtime, jsInvoker_)
.release();

jclass jPromiseImpl =
env->FindClass("com/facebook/react/bridge/PromiseImpl");
jmethodID jPromiseImplConstructor = env->GetMethodID(
jPromiseImpl,
"<init>",
"(Lcom/facebook/react/bridge/Callback;Lcom/facebook/react/bridge/Callback;)V");

jobject promise = env->NewObject(
jPromiseImpl, jPromiseImplConstructor, resolve, reject);

// Require the env from the current scope, which may be
// different from the original invocation's scope
JNIEnv* env = jni::Environment::current();
const char* moduleName = moduleNameStr.c_str();
const char* methodName = methodNameStr.c_str();
TMPL::asyncMethodCallExecutionStart(moduleName, methodName, id);
env->CallVoidMethodA(instance.get(), methodID, jargs.data());
try {
FACEBOOK_JNI_THROW_PENDING_EXCEPTION();
} catch (...) {
TMPL::asyncMethodCallExecutionFail(moduleName, methodName, id);
throw;
}

jobject globalPromise = env->NewGlobalRef(promise);

globalRefs.push_back(globalPromise);
env->DeleteLocalRef(promise);

jargs[argCount].l = globalPromise;
TMPL::asyncMethodCallArgConversionEnd(moduleName, methodName);
TMPL::asyncMethodCallDispatch(moduleName, methodName);

nativeMethodCallInvoker_->invokeAsync(
methodName,
[jargs,
globalRefs,
methodID,
instance_ = jni::make_weak(instance_),
moduleNameStr,
methodNameStr,
id = getUniqueId()]() mutable -> void {
auto instance = instance_.lockLocal();

if (!instance) {
return;
}
/**
* TODO(ramanpreet): Why do we have to require the
* environment again? Why does JNI crash when we use the env
* from the upper scope?
*/
JNIEnv* env = jni::Environment::current();
const char* moduleName = moduleNameStr.c_str();
const char* methodName = methodNameStr.c_str();

TMPL::asyncMethodCallExecutionStart(
moduleName, methodName, id);
env->CallVoidMethodA(instance.get(), methodID, jargs.data());
try {
FACEBOOK_JNI_THROW_PENDING_EXCEPTION();
} catch (...) {
TMPL::asyncMethodCallExecutionFail(
moduleName, methodName, id);
throw;
}

for (auto globalRef : globalRefs) {
env->DeleteGlobalRef(globalRef);
}
TMPL::asyncMethodCallExecutionEnd(moduleName, methodName, id);
});

return jsi::Value::undefined();
for (auto globalRef : globalRefs) {
env->DeleteGlobalRef(globalRef);
}
TMPL::asyncMethodCallExecutionEnd(moduleName, methodName, id);
});

jsi::Value promise =
Promise.callAsConstructor(runtime, promiseConstructorArg);
checkJNIErrorForMethodCall();

TMPL::asyncMethodCallEnd(moduleName, methodName);

return promise;
return jsPromise;
}
default:
throw std::runtime_error(
Expand Down

0 comments on commit e517072

Please sign in to comment.