Skip to content

Commit

Permalink
Remove runtime code paths for field stripping.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 500274174
  • Loading branch information
ckennelly authored and copybara-github committed Jan 9, 2023
1 parent 37a4656 commit 2d6ee17
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 85 deletions.
54 changes: 3 additions & 51 deletions src/google/protobuf/generated_message_reflection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,6 @@ static void ReportReflectionUsageEnumTypeError(
<< value->full_name();
}

inline void CheckInvalidAccess(const internal::ReflectionSchema& schema,
const FieldDescriptor* field) {
GOOGLE_ABSL_CHECK(!schema.IsFieldStripped(field))
<< "invalid access to a stripped field " << field->full_name();
}

#define USAGE_CHECK(CONDITION, METHOD, ERROR_DESCRIPTION) \
if (!(CONDITION)) \
ReportReflectionUsageError(descriptor_, field, #METHOD, ERROR_DESCRIPTION)
Expand Down Expand Up @@ -1081,7 +1075,6 @@ void Reflection::SwapFieldsImpl(
const Message* prototype =
message_factory_->GetPrototype(message1->GetDescriptor());
for (const auto* field : fields) {
CheckInvalidAccess(schema_, field);
if (field->is_extension()) {
if (unsafe_shallow_swap) {
MutableExtensionSet(message1)->UnsafeShallowSwapExtension(
Expand Down Expand Up @@ -1152,7 +1145,6 @@ bool Reflection::HasField(const Message& message,
const FieldDescriptor* field) const {
USAGE_CHECK_MESSAGE_TYPE(HasField);
USAGE_CHECK_SINGULAR(HasField);
CheckInvalidAccess(schema_, field);
if (field->is_extension()) {
return GetExtensionSet(message).Has(field->number());
Expand All @@ -1178,7 +1170,6 @@ void Reflection::InternalSwap(Message* lhs, Message* rhs) const {
for (int i = 0; i <= last_non_weak_field_index_; i++) {
const FieldDescriptor* field = descriptor_->field(i);
if (schema_.InRealOneof(field)) continue;
if (schema_.IsFieldStripped(field)) continue;
if (schema_.IsSplit(field)) {
continue;
}
Expand Down Expand Up @@ -1252,7 +1243,6 @@ int Reflection::FieldSize(const Message& message,
const FieldDescriptor* field) const {
USAGE_CHECK_MESSAGE_TYPE(FieldSize);
USAGE_CHECK_REPEATED(FieldSize);
CheckInvalidAccess(schema_, field);
if (field->is_extension()) {
return GetExtensionSet(message).ExtensionSize(field->number());
Expand Down Expand Up @@ -1297,7 +1287,6 @@ int Reflection::FieldSize(const Message& message,
void Reflection::ClearField(Message* message,
const FieldDescriptor* field) const {
USAGE_CHECK_MESSAGE_TYPE(ClearField);
CheckInvalidAccess(schema_, field);
if (field->is_extension()) {
MutableExtensionSet(message)->ClearExtension(field->number());
Expand Down Expand Up @@ -1408,7 +1397,6 @@ void Reflection::RemoveLast(Message* message,
const FieldDescriptor* field) const {
USAGE_CHECK_MESSAGE_TYPE(RemoveLast);
USAGE_CHECK_REPEATED(RemoveLast);
CheckInvalidAccess(schema_, field);

if (field->is_extension()) {
MutableExtensionSet(message)->RemoveLast(field->number());
Expand Down Expand Up @@ -1456,7 +1444,6 @@ void Reflection::RemoveLast(Message* message,
Message* Reflection::ReleaseLast(Message* message,
const FieldDescriptor* field) const {
USAGE_CHECK_ALL(ReleaseLast, REPEATED, MESSAGE);
CheckInvalidAccess(schema_, field);

Message* released;
if (field->is_extension()) {
Expand All @@ -1482,7 +1469,6 @@ Message* Reflection::ReleaseLast(Message* message,
Message* Reflection::UnsafeArenaReleaseLast(
Message* message, const FieldDescriptor* field) const {
USAGE_CHECK_ALL(UnsafeArenaReleaseLast, REPEATED, MESSAGE);
CheckInvalidAccess(schema_, field);

if (field->is_extension()) {
return static_cast<Message*>(
Expand All @@ -1503,7 +1489,6 @@ void Reflection::SwapElements(Message* message, const FieldDescriptor* field,
int index1, int index2) const {
USAGE_CHECK_MESSAGE_TYPE(Swap);
USAGE_CHECK_REPEATED(Swap);
CheckInvalidAccess(schema_, field);

if (field->is_extension()) {
MutableExtensionSet(message)->SwapElements(field->number(), index1, index2);
Expand Down Expand Up @@ -1575,9 +1560,8 @@ bool CreateUnknownEnumValues(const FieldDescriptor* field) {
} // namespace internal
using internal::CreateUnknownEnumValues;

void Reflection::ListFieldsMayFailOnStripped(
const Message& message, bool should_fail,
std::vector<const FieldDescriptor*>* output) const {
void Reflection::ListFields(const Message& message,
std::vector<const FieldDescriptor*>* output) const {
output->clear();

// Optimization: The default instance never has any fields set.
Expand All @@ -1601,9 +1585,6 @@ void Reflection::ListFieldsMayFailOnStripped(
};
for (int i = 0; i <= last_non_weak_field_index; i++) {
const FieldDescriptor* field = descriptor_->field(i);
if (!should_fail && schema_.IsFieldStripped(field)) {
continue;
}
if (field->is_repeated()) {
if (FieldSize(message, field) > 0) {
append_to_output(field);
Expand All @@ -1620,7 +1601,6 @@ void Reflection::ListFieldsMayFailOnStripped(
append_to_output(field);
}
} else if (has_bits && has_bits_indices[i] != static_cast<uint32_t>(-1)) {
CheckInvalidAccess(schema_, field);
// Equivalent to: HasBit(message, field)
if (IsIndexInHasBitSet(has_bits, has_bits_indices[i])) {
append_to_output(field);
Expand Down Expand Up @@ -1657,16 +1637,6 @@ void Reflection::ListFieldsMayFailOnStripped(
}
}

void Reflection::ListFields(const Message& message,
std::vector<const FieldDescriptor*>* output) const {
ListFieldsMayFailOnStripped(message, true, output);
}

void Reflection::ListFieldsOmitStripped(
const Message& message, std::vector<const FieldDescriptor*>* output) const {
ListFieldsMayFailOnStripped(message, false, output);
}

// -------------------------------------------------------------------

#undef DEFINE_PRIMITIVE_ACCESSORS
Expand Down Expand Up @@ -2089,7 +2059,6 @@ const Message& Reflection::GetMessage(const Message& message,
const FieldDescriptor* field,
MessageFactory* factory) const {
USAGE_CHECK_ALL(GetMessage, SINGULAR, MESSAGE);
CheckInvalidAccess(schema_, field);

if (factory == nullptr) factory = message_factory_;

Expand All @@ -2112,7 +2081,6 @@ Message* Reflection::MutableMessage(Message* message,
const FieldDescriptor* field,
MessageFactory* factory) const {
USAGE_CHECK_ALL(MutableMessage, SINGULAR, MESSAGE);
CheckInvalidAccess(schema_, field);

if (factory == nullptr) factory = message_factory_;

Expand Down Expand Up @@ -2148,7 +2116,6 @@ void Reflection::UnsafeArenaSetAllocatedMessage(
Message* message, Message* sub_message,
const FieldDescriptor* field) const {
USAGE_CHECK_ALL(SetAllocatedMessage, SINGULAR, MESSAGE);
CheckInvalidAccess(schema_, field);


if (field->is_extension()) {
Expand Down Expand Up @@ -2184,7 +2151,6 @@ void Reflection::SetAllocatedMessage(Message* message, Message* sub_message,
GOOGLE_ABSL_DCHECK(
sub_message == nullptr || sub_message->GetOwningArena() == nullptr ||
sub_message->GetOwningArena() == message->GetArenaForAllocation());
CheckInvalidAccess(schema_, field);

// If message and sub-message are in different memory ownership domains
// (different arenas, or one is on heap and one is not), then we may need to
Expand Down Expand Up @@ -2215,7 +2181,6 @@ Message* Reflection::UnsafeArenaReleaseMessage(Message* message,
const FieldDescriptor* field,
MessageFactory* factory) const {
USAGE_CHECK_ALL(ReleaseMessage, SINGULAR, MESSAGE);
CheckInvalidAccess(schema_, field);

if (factory == nullptr) factory = message_factory_;

Expand Down Expand Up @@ -2244,8 +2209,6 @@ Message* Reflection::UnsafeArenaReleaseMessage(Message* message,
Message* Reflection::ReleaseMessage(Message* message,
const FieldDescriptor* field,
MessageFactory* factory) const {
CheckInvalidAccess(schema_, field);

Message* released = UnsafeArenaReleaseMessage(message, field, factory);
#ifdef PROTOBUF_FORCE_COPY_IN_RELEASE
released = MaybeForceCopy(message->GetArenaForAllocation(), released);
Expand All @@ -2262,7 +2225,6 @@ const Message& Reflection::GetRepeatedMessage(const Message& message,
const FieldDescriptor* field,
int index) const {
USAGE_CHECK_ALL(GetRepeatedMessage, REPEATED, MESSAGE);
CheckInvalidAccess(schema_, field);

if (field->is_extension()) {
return static_cast<const Message&>(
Expand All @@ -2283,7 +2245,6 @@ Message* Reflection::MutableRepeatedMessage(Message* message,
const FieldDescriptor* field,
int index) const {
USAGE_CHECK_ALL(MutableRepeatedMessage, REPEATED, MESSAGE);
CheckInvalidAccess(schema_, field);

if (field->is_extension()) {
return static_cast<Message*>(
Expand All @@ -2304,7 +2265,6 @@ Message* Reflection::MutableRepeatedMessage(Message* message,
Message* Reflection::AddMessage(Message* message, const FieldDescriptor* field,
MessageFactory* factory) const {
USAGE_CHECK_ALL(AddMessage, REPEATED, MESSAGE);
CheckInvalidAccess(schema_, field);

if (factory == nullptr) factory = message_factory_;

Expand Down Expand Up @@ -2347,7 +2307,6 @@ void Reflection::AddAllocatedMessage(Message* message,
const FieldDescriptor* field,
Message* new_entry) const {
USAGE_CHECK_ALL(AddAllocatedMessage, REPEATED, MESSAGE);
CheckInvalidAccess(schema_, field);

if (field->is_extension()) {
MutableExtensionSet(message)->AddAllocatedMessage(field, new_entry);
Expand All @@ -2367,7 +2326,6 @@ void Reflection::UnsafeArenaAddAllocatedMessage(Message* message,
const FieldDescriptor* field,
Message* new_entry) const {
USAGE_CHECK_ALL(UnsafeArenaAddAllocatedMessage, REPEATED, MESSAGE);
CheckInvalidAccess(schema_, field);

if (field->is_extension()) {
MutableExtensionSet(message)->UnsafeArenaAddAllocatedMessage(field,
Expand All @@ -2391,7 +2349,6 @@ void* Reflection::MutableRawRepeatedField(Message* message,
const Descriptor* desc) const {
(void)ctype; // Parameter is used by Google-internal code.
USAGE_CHECK_REPEATED("MutableRawRepeatedField");
CheckInvalidAccess(schema_, field);

if (field->cpp_type() != cpptype &&
(field->cpp_type() != FieldDescriptor::CPPTYPE_ENUM ||
Expand Down Expand Up @@ -2690,9 +2647,6 @@ bool Reflection::HasBit(const Message& message,
return IsIndexInHasBitSet(GetHasBits(message), schema_.HasBitIndex(field));
}

// Intentionally check here because HasBitIndex(field) != -1 means valid.
CheckInvalidAccess(schema_, field);

// proto3: no has-bits. All fields present except messages, which are
// present only if their message-field pointer is non-null.
if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
Expand Down Expand Up @@ -3277,8 +3231,6 @@ const internal::TcParseTableBase* Reflection::CreateTcParseTable() const {
std::vector<int> inlined_string_indices = has_bit_indices;
for (int i = 0; i < descriptor_->field_count(); ++i) {
auto* field = descriptor_->field(i);
if (schema_.IsFieldStripped(field)) continue;

fields.push_back(field);
has_bit_indices[static_cast<size_t>(field->index())] =
static_cast<int>(schema_.HasBitIndex(field));
Expand Down Expand Up @@ -3624,7 +3576,7 @@ void UnknownFieldSetSerializer(const uint8_t* base, uint32_t offset,
bool IsDescendant(Message& root, const Message& message) {
const Reflection* reflection = root.GetReflection();
std::vector<const FieldDescriptor*> fields;
reflection->ListFieldsOmitStripped(root, &fields);
reflection->ListFields(root, &fields);

for (const auto* field : fields) {
// Skip non-message fields.
Expand Down
10 changes: 0 additions & 10 deletions src/google/protobuf/generated_message_reflection.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,16 +236,6 @@ struct ReflectionSchema {
return false;
}

bool IsFieldStripped(const FieldDescriptor* field) const {
(void)field;
return false;
}

bool IsMessageStripped(const Descriptor* descriptor) const {
(void)descriptor;
return false;
}

bool IsSplit() const { return split_offset_ != -1; }

bool IsSplit(const FieldDescriptor* field) const {
Expand Down
16 changes: 0 additions & 16 deletions src/google/protobuf/message.h
Original file line number Diff line number Diff line change
Expand Up @@ -1024,22 +1024,6 @@ class PROTOBUF_EXPORT Reflection final {
const internal::RepeatedFieldAccessor* RepeatedFieldAccessor(
const FieldDescriptor* field) const;

// Lists all fields of the message which are currently set, except for unknown
// fields and stripped fields. See ListFields for details.
void ListFieldsOmitStripped(
const Message& message,
std::vector<const FieldDescriptor*>* output) const;

bool IsMessageStripped(const Descriptor* descriptor) const {
return schema_.IsMessageStripped(descriptor);
}

friend class TextFormat;

void ListFieldsMayFailOnStripped(
const Message& message, bool should_fail,
std::vector<const FieldDescriptor*>* output) const;

// Returns true if the message field is backed by a LazyField.
//
// A message field may be backed by a LazyField without the user annotation
Expand Down
8 changes: 4 additions & 4 deletions src/google/protobuf/reflection_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ void ReflectionOps::Merge(const Message& from, Message* to) {
google::protobuf::MessageFactory::generated_factory());

std::vector<const FieldDescriptor*> fields;
from_reflection->ListFieldsOmitStripped(from, &fields);
from_reflection->ListFields(from, &fields);
for (const FieldDescriptor* field : fields) {
if (field->is_repeated()) {
// Use map reflection if both are in map status and have the
Expand Down Expand Up @@ -182,7 +182,7 @@ void ReflectionOps::Clear(Message* message) {
const Reflection* reflection = GetReflectionOrDie(*message);

std::vector<const FieldDescriptor*> fields;
reflection->ListFieldsOmitStripped(*message, &fields);
reflection->ListFields(*message, &fields);
for (const FieldDescriptor* field : fields) {
reflection->ClearField(message, field);
}
Expand Down Expand Up @@ -274,7 +274,7 @@ bool ReflectionOps::IsInitialized(const Message& message) {
std::vector<const FieldDescriptor*> fields;
// Should be safe to skip stripped fields because required fields are not
// stripped.
reflection->ListFieldsOmitStripped(message, &fields);
reflection->ListFields(message, &fields);
for (const FieldDescriptor* field : fields) {
if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {

Expand Down Expand Up @@ -403,7 +403,7 @@ void ReflectionOps::FindInitializationErrors(const Message& message,

// Check sub-messages.
std::vector<const FieldDescriptor*> fields;
reflection->ListFieldsOmitStripped(message, &fields);
reflection->ListFields(message, &fields);
for (const FieldDescriptor* field : fields) {
if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {

Expand Down
5 changes: 1 addition & 4 deletions src/google/protobuf/text_format.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2252,10 +2252,7 @@ void TextFormat::Printer::PrintMessage(const Message& message,
fields.push_back(descriptor->field(0));
fields.push_back(descriptor->field(1));
} else {
reflection->ListFieldsOmitStripped(message, &fields);
if (reflection->IsMessageStripped(message.GetDescriptor())) {
generator->Print(kDoNotParse, std::strlen(kDoNotParse));
}
reflection->ListFields(message, &fields);
}

if (print_message_fields_in_index_order_) {
Expand Down

0 comments on commit 2d6ee17

Please sign in to comment.