Skip to content

Commit

Permalink
Rust protobuf: remove the need for a generated placement_new thunk
Browse files Browse the repository at this point in the history
We have been relying on a per-message generated `placement_new` function for
implementing map insertion, but this CL simplifies things by removing that.
Instead, we do a reflective swap if possible, or else fall back on a copy.

This will probably make insertions a bit slower, but I think it may be worth it
because it should make it much simpler to have a blanket implementation for
ProxedInMapValue that works for all map types.

It looks like it should be possible to make this faster in the future by
implementing a bitwise move that will work for any message.

PiperOrigin-RevId: 676495920
  • Loading branch information
acozzette authored and copybara-github committed Sep 19, 2024
1 parent 8681742 commit 5c3d1e8
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 40 deletions.
1 change: 0 additions & 1 deletion rust/cpp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,6 @@ macro_rules! impl_map_primitives {
size_info: MapNodeSizeInfo,
key: $cpp_type,
value: RawMessage,
placement_new: unsafe extern "C" fn(*mut c_void, m: RawMessage),
) -> bool;
pub fn $get_thunk(
m: RawMap,
Expand Down
73 changes: 42 additions & 31 deletions rust/cpp_kernel/map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <utility>

#include "google/protobuf/map.h"
#include "google/protobuf/message.h"
#include "google/protobuf/message_lite.h"
#include "rust/cpp_kernel/strings.h"

Expand Down Expand Up @@ -42,16 +43,27 @@ void DestroyMapNode(internal::UntypedMapBase* m, internal::NodeBase* node,

template <typename Key>
bool Insert(internal::UntypedMapBase* m, internal::MapNodeSizeInfoT size_info,
Key key, MessageLite* value,
void (*placement_new)(void*, MessageLite*)) {
Key key, MessageLite* value) {
internal::NodeBase* node = internal::RustMapHelper::AllocNode(m, size_info);
if constexpr (std::is_same<Key, PtrAndLen>::value) {
new (node->GetVoidKey()) std::string(key.ptr, key.len);
} else {
*static_cast<Key*>(node->GetVoidKey()) = key;
}
void* value_ptr = node->GetVoidValue(size_info);
placement_new(value_ptr, value);

MessageLite* new_msg = internal::RustMapHelper::PlacementNew(
value, node->GetVoidValue(size_info));
auto* full_msg = DynamicCastMessage<Message>(new_msg);

// If we are working with a full (non-lite) proto, we reflectively swap the
// value into place. Otherwise, we have to perform a copy.
if (full_msg != nullptr) {
full_msg->GetReflection()->Swap(full_msg,
DynamicCastMessage<Message>(value));
} else {
new_msg->CheckTypeAndMergeFrom(*value);
}

node = internal::RustMapHelper::InsertOrReplaceNode(
static_cast<KeyMap<Key>*>(m), node);
if (node == nullptr) {
Expand Down Expand Up @@ -166,33 +178,32 @@ google::protobuf::internal::UntypedMapIterator proto2_rust_map_iter(
return m->begin();
}

#define DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(cpp_type, suffix) \
bool proto2_rust_map_insert_##suffix( \
google::protobuf::internal::UntypedMapBase* m, \
google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type key, \
google::protobuf::MessageLite* value, \
void (*placement_new)(void*, google::protobuf::MessageLite*)) { \
return google::protobuf::rust::Insert(m, size_info, key, value, placement_new); \
} \
\
bool proto2_rust_map_get_##suffix( \
google::protobuf::internal::UntypedMapBase* m, \
google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type key, \
google::protobuf::MessageLite** value) { \
return google::protobuf::rust::Get(m, size_info, key, value); \
} \
\
bool proto2_rust_map_remove_##suffix( \
google::protobuf::internal::UntypedMapBase* m, \
google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type key) { \
return google::protobuf::rust::Remove(m, size_info, key); \
} \
\
void proto2_rust_map_iter_get_##suffix( \
const google::protobuf::internal::UntypedMapIterator* iter, \
google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type* key, \
google::protobuf::MessageLite** value) { \
return google::protobuf::rust::IterGet(iter, size_info, key, value); \
#define DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(cpp_type, suffix) \
bool proto2_rust_map_insert_##suffix( \
google::protobuf::internal::UntypedMapBase* m, \
google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type key, \
google::protobuf::MessageLite* value) { \
return google::protobuf::rust::Insert(m, size_info, key, value); \
} \
\
bool proto2_rust_map_get_##suffix( \
google::protobuf::internal::UntypedMapBase* m, \
google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type key, \
google::protobuf::MessageLite** value) { \
return google::protobuf::rust::Get(m, size_info, key, value); \
} \
\
bool proto2_rust_map_remove_##suffix( \
google::protobuf::internal::UntypedMapBase* m, \
google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type key) { \
return google::protobuf::rust::Remove(m, size_info, key); \
} \
\
void proto2_rust_map_iter_get_##suffix( \
const google::protobuf::internal::UntypedMapIterator* iter, \
google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type* key, \
google::protobuf::MessageLite** value) { \
return google::protobuf::rust::IterGet(iter, size_info, key, value); \
}

DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(int32_t, i32)
Expand Down
9 changes: 1 addition & 8 deletions src/google/protobuf/compiler/rust/message.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ void CppMessageExterns(Context& ctx, const Descriptor& msg) {
ABSL_CHECK(ctx.is_cpp());
ctx.Emit(
{{"new_thunk", ThunkName(ctx, msg, "new")},
{"placement_new_thunk", ThunkName(ctx, msg, "placement_new")},
{"repeated_new_thunk", ThunkName(ctx, msg, "repeated_new")},
{"repeated_free_thunk", ThunkName(ctx, msg, "repeated_free")},
{"repeated_len_thunk", ThunkName(ctx, msg, "repeated_len")},
Expand All @@ -208,7 +207,6 @@ void CppMessageExterns(Context& ctx, const Descriptor& msg) {
{"map_size_info_thunk", ThunkName(ctx, msg, "size_info")}},
R"rs(
fn $new_thunk$() -> $pbr$::RawMessage;
fn $placement_new_thunk$(ptr: *mut $std$::ffi::c_void, m: $pbr$::RawMessage);
fn $repeated_new_thunk$() -> $pbr$::RawRepeatedField;
fn $repeated_free_thunk$(raw: $pbr$::RawRepeatedField);
fn $repeated_len_thunk$(raw: $pbr$::RawRepeatedField) -> usize;
Expand Down Expand Up @@ -566,7 +564,6 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) {
for (const auto& t : kMapKeyTypes) {
ctx.Emit(
{{"map_size_info_thunk", ThunkName(ctx, msg, "size_info")},
{"placement_new_thunk", ThunkName(ctx, msg, "placement_new")},
{"map_insert",
absl::StrCat("proto2_rust_map_insert_", t.thunk_ident)},
{"map_remove",
Expand Down Expand Up @@ -616,7 +613,7 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) {
map.as_raw($pbi$::Private),
$map_size_info_thunk$($key_t$::SIZE_INFO_INDEX),
$key_expr$,
value.into_proxied($pbi$::Private).raw_msg(), $placement_new_thunk$)
value.into_proxied($pbi$::Private).raw_msg())
}
}
Expand Down Expand Up @@ -1355,7 +1352,6 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) {
{"Msg", RsSafeName(msg.name())},
{"QualifiedMsg", cpp::QualifiedClassName(&msg)},
{"new_thunk", ThunkName(ctx, msg, "new")},
{"placement_new_thunk", ThunkName(ctx, msg, "placement_new")},
{"repeated_new_thunk", ThunkName(ctx, msg, "repeated_new")},
{"repeated_free_thunk", ThunkName(ctx, msg, "repeated_free")},
{"repeated_len_thunk", ThunkName(ctx, msg, "repeated_len")},
Expand Down Expand Up @@ -1391,9 +1387,6 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) {
// clang-format off
extern $abi$ {
void* $new_thunk$() { return new $QualifiedMsg$(); }
void $placement_new_thunk$(void* ptr, $QualifiedMsg$& m) {
new (ptr) $QualifiedMsg$(std::move(m));
}
void* $repeated_new_thunk$() {
return new google::protobuf::RepeatedPtrField<$QualifiedMsg$>();
Expand Down
5 changes: 5 additions & 0 deletions src/google/protobuf/map.h
Original file line number Diff line number Diff line change
Expand Up @@ -1190,6 +1190,11 @@ class RustMapHelper {
m->erase_no_destroy(bucket, static_cast<typename Map::KeyNode*>(node));
}

static google::protobuf::MessageLite* PlacementNew(const MessageLite* prototype,
void* mem) {
return prototype->GetClassData()->PlacementNew(mem, /* arena = */ nullptr);
}

static void DestroyMessage(MessageLite* m) { m->DestroyInstance(); }

static void ClearTable(UntypedMapBase* m, ClearInput input) {
Expand Down

0 comments on commit 5c3d1e8

Please sign in to comment.