diff --git a/upb/mini_table_accessors.c b/upb/mini_table_accessors.c index 325062ddeb..bb99d25537 100644 --- a/upb/mini_table_accessors.c +++ b/upb/mini_table_accessors.c @@ -33,6 +33,7 @@ #include "upb/mini_table.h" // Must be last. +#include "upb/msg.h" #include "upb/port_def.inc" size_t upb_MiniTable_Field_GetSize(const upb_MiniTable_Field* f) { @@ -172,6 +173,37 @@ static const char* decode_tag(const char* ptr, uint32_t* val) { } } +// Parses unknown data by merging into existing base_message or creating a +// new message usingg mini_table. +static upb_UnknownToMessageRet upb_MiniTable_ParseUnknownMessage( + const char* unknown_data, size_t unknown_size, + const upb_MiniTable* mini_table, upb_Message* base_message, + int decode_options, upb_Arena* arena) { + upb_UnknownToMessageRet ret; + ret.message = + base_message ? base_message : _upb_Message_New(mini_table, arena); + if (!ret.message) { + ret.status = kUpb_UnknownToMessage_OutOfMemory; + return ret; + } + // Decode sub message using unknown field contents. + const char* data = unknown_data; + uint32_t tag; + uint64_t message_len = 0; + data = decode_tag(data, &tag); + data = decode_varint64(data, &message_len); + upb_DecodeStatus status = upb_Decode(data, message_len, ret.message, + mini_table, NULL, decode_options, arena); + if (status == kUpb_DecodeStatus_OutOfMemory) { + ret.status = kUpb_UnknownToMessage_OutOfMemory; + } else if (status == kUpb_DecodeStatus_Ok) { + ret.status = kUpb_UnknownToMessage_Ok; + } else { + ret.status = kUpb_UnknownToMessage_ParseError; + } + return ret; +} + upb_GetExtension_Status upb_MiniTable_GetOrPromoteExtension( upb_Message* msg, const upb_MiniTable_Extension* ext_table, int decode_options, upb_Arena* arena, @@ -190,22 +222,20 @@ upb_GetExtension_Status upb_MiniTable_GetOrPromoteExtension( } // Decode and promote from unknown. const upb_MiniTable* extension_table = ext_table->sub.submsg; - upb_Message* extension_msg = _upb_Message_New(extension_table, arena); - if (!extension_msg) { - return kUpb_GetExtension_OutOfMemory; - } - const char* data = result.ptr; - uint32_t tag; - uint64_t message_len = 0; - data = decode_tag(data, &tag); - data = decode_varint64(data, &message_len); - upb_DecodeStatus status = - upb_Decode(data, message_len, extension_msg, extension_table, NULL, - decode_options, arena); - if (status == kUpb_DecodeStatus_OutOfMemory) { - return kUpb_GetExtension_OutOfMemory; + upb_UnknownToMessageRet parse_result = upb_MiniTable_ParseUnknownMessage( + result.ptr, result.len, extension_table, + /* base_message= */ NULL, decode_options, arena); + switch (parse_result.status) { + case kUpb_UnknownToMessage_OutOfMemory: + return kUpb_GetExtension_OutOfMemory; + case kUpb_UnknownToMessage_ParseError: + return kUpb_GetExtension_ParseError; + case kUpb_UnknownToMessage_NotFound: + return kUpb_GetExtension_NotPresent; + case kUpb_UnknownToMessage_Ok: + break; } - if (status != kUpb_DecodeStatus_Ok) return kUpb_GetExtension_ParseError; + upb_Message* extension_msg = parse_result.message; // Add to extensions. upb_Message_Extension* ext = _upb_Message_GetOrCreateExtension(msg, ext_table, arena); @@ -214,15 +244,7 @@ upb_GetExtension_Status upb_MiniTable_GetOrPromoteExtension( } memcpy(&ext->data, &extension_msg, sizeof(extension_msg)); *extension = ext; - // Remove unknown field. - upb_Message_Internal* in = upb_Message_Getinternal(msg); - const char* internal_unknown_end = - UPB_PTR_AT(in->internal, in->internal->unknown_end, char); - if ((result.ptr + result.len) != internal_unknown_end) { - memmove((char*)result.ptr, result.ptr + result.len, - internal_unknown_end - result.ptr - result.len); - } - in->internal->unknown_end -= result.len; + upb_Message_DeleteUnknown(msg, result.ptr, result.len); return kUpb_GetExtension_Ok; } @@ -382,3 +404,79 @@ upb_FindUnknownRet upb_MiniTable_FindUnknown(const upb_Message* msg, ret.len = 0; return ret; } + +upb_UnknownToMessageRet upb_MiniTable_PromoteUnknownToMessage( + upb_Message* msg, const upb_MiniTable* mini_table, + const upb_MiniTable_Field* field, const upb_MiniTable* sub_mini_table, + int decode_options, upb_Arena* arena) { + upb_FindUnknownRet unknown; + // We need to loop and merge unknowns that have matching tag field->number. + upb_Message* message = NULL; + // Callers should check that message is not set first before calling + // PromotoUnknownToMessage. + UPB_ASSERT(upb_MiniTable_GetMessage(msg, field) == NULL); + upb_UnknownToMessageRet ret; + ret.status = kUpb_UnknownToMessage_Ok; + do { + unknown = upb_MiniTable_FindUnknown(msg, field->number); + switch (unknown.status) { + case kUpb_FindUnknown_Ok: { + const char* unknown_data = unknown.ptr; + size_t unknown_size = unknown.len; + ret = upb_MiniTable_ParseUnknownMessage(unknown_data, unknown_size, + sub_mini_table, message, + decode_options, arena); + if (ret.status == kUpb_UnknownToMessage_Ok) { + message = ret.message; + upb_Message_DeleteUnknown(msg, unknown_data, unknown_size); + } + } break; + case kUpb_FindUnknown_ParseError: + ret.status = kUpb_UnknownToMessage_ParseError; + break; + case kUpb_FindUnknown_NotPresent: + // If we parsed at least one unknown, we are done. + ret.status = + message ? kUpb_UnknownToMessage_Ok : kUpb_UnknownToMessage_NotFound; + break; + } + } while (unknown.status == kUpb_FindUnknown_Ok); + if (message) { + upb_MiniTable_SetMessage(msg, mini_table, field, message); + ret.message = message; + } + return ret; +} + +// Moves repeated messages in unknowns to a upb_Array. +// +// Since the repeated field is not a scalar type we don't check for +// kUpb_LabelFlags_IsPacked. +// TODO(b/251007554): Optimize. Instead of converting messages one at a time, +// scan all unknown data once and compact. +upb_UnknownToMessage_Status upb_MiniTable_PromoteUnknownToMessageArray( + upb_Message* msg, const upb_MiniTable_Field* field, + const upb_MiniTable* mini_table, int decode_options, upb_Arena* arena) { + upb_Array* repeated_messages = upb_MiniTable_GetMutableArray(msg, field); + // Find all unknowns with given field number and parse. + upb_FindUnknownRet unknown; + do { + unknown = upb_MiniTable_FindUnknown(msg, field->number); + if (unknown.status == kUpb_FindUnknown_Ok) { + upb_UnknownToMessageRet ret = upb_MiniTable_ParseUnknownMessage( + unknown.ptr, unknown.len, mini_table, + /* base_message= */ NULL, decode_options, arena); + if (ret.status == kUpb_UnknownToMessage_Ok) { + upb_MessageValue value; + value.msg_val = ret.message; + if (!upb_Array_Append(repeated_messages, value, arena)) { + return kUpb_UnknownToMessage_OutOfMemory; + } + upb_Message_DeleteUnknown(msg, unknown.ptr, unknown.len); + } else { + return ret.status; + } + } + } while (unknown.status == kUpb_FindUnknown_Ok); + return kUpb_UnknownToMessage_Ok; +} diff --git a/upb/mini_table_accessors.h b/upb/mini_table_accessors.h index 0c3ec7cf3b..bb509513c0 100644 --- a/upb/mini_table_accessors.h +++ b/upb/mini_table_accessors.h @@ -31,7 +31,6 @@ #include "upb/array.h" #include "upb/internal/mini_table_accessors.h" #include "upb/mini_table.h" -#include "upb/msg_internal.h" // Must be last. #include "upb/port_def.inc" @@ -279,9 +278,41 @@ typedef struct { size_t len; } upb_FindUnknownRet; +// Finds first occurrence of unknown data by tag id in message. upb_FindUnknownRet upb_MiniTable_FindUnknown(const upb_Message* msg, uint32_t field_number); +typedef enum { + kUpb_UnknownToMessage_Ok, + kUpb_UnknownToMessage_ParseError, + kUpb_UnknownToMessage_OutOfMemory, + kUpb_UnknownToMessage_NotFound, +} upb_UnknownToMessage_Status; + +typedef struct { + upb_UnknownToMessage_Status status; + upb_Message* message; +} upb_UnknownToMessageRet; + +// Promotes unknown data inside message to a upb_Message parsing the unknown. +// +// The unknown data is removed from message after field value is set +// using upb_MiniTable_SetMessage. +upb_UnknownToMessageRet upb_MiniTable_PromoteUnknownToMessage( + upb_Message* msg, const upb_MiniTable* mini_table, + const upb_MiniTable_Field* field, const upb_MiniTable* sub_mini_table, + int decode_options, upb_Arena* arena); + +// Promotes all unknown data that matches field tag id to repeated messages +// in upb_Array. +// +// The unknown data is removed from message after upb_Array is populated. +// Since repeated messages can't be packed we remove each unknown that +// contains the target tag id. +upb_UnknownToMessage_Status upb_MiniTable_PromoteUnknownToMessageArray( + upb_Message* msg, const upb_MiniTable_Field* field, + const upb_MiniTable* mini_table, int decode_options, upb_Arena* arena); + #ifdef __cplusplus } /* extern "C" */ #endif diff --git a/upb/mini_table_accessors_test.cc b/upb/mini_table_accessors_test.cc index 77bfb4417c..65873dc9dc 100644 --- a/upb/mini_table_accessors_test.cc +++ b/upb/mini_table_accessors_test.cc @@ -37,8 +37,11 @@ #include "google/protobuf/test_messages_proto2.upb.h" #include "google/protobuf/test_messages_proto3.upb.h" #include "upb/array.h" +#include "upb/decode.h" #include "upb/mini_table.h" +#include "upb/msg_internal.h" #include "upb/test.upb.h" +#include "upb/upb.h" namespace { @@ -474,4 +477,77 @@ TEST(GeneratedCode, Extensions) { upb_Arena_Free(arena); } +// Create a minitable to mimic ModelWithSubMessages with unlinked subs +// to lazily promote unknowns after parsing. +upb_MiniTable* CreateMiniTableWithEmptySubTables(upb_Arena* arena) { + upb_MtDataEncoder e; + const size_t kBufferSize = 256; + char buf[kBufferSize]; + char* ptr = buf; + e.end = ptr + kBufferSize; + ptr = upb_MtDataEncoder_StartMessage(&e, ptr, /* msg_mod= */ 0); + EXPECT_TRUE(ptr != nullptr); + ptr = upb_MtDataEncoder_PutField(&e, ptr, kUpb_FieldType_Int32, 4, 0); + ptr = upb_MtDataEncoder_PutField(&e, ptr, kUpb_FieldType_Message, 5, 0); + ptr = upb_MtDataEncoder_PutField(&e, ptr, kUpb_FieldType_Message, 6, + kUpb_FieldModifier_IsRepeated); + + upb_Status status; + upb_Status_Clear(&status); + upb_MiniTable* table = upb_MiniTable_Build( + buf, ptr - buf, kUpb_MiniTablePlatform_Native, arena, &status); + EXPECT_EQ(status.ok, true); + // Initialize sub table to null. Not using upb_MiniTable_SetSubMessage + // since it checks ->ext on parameter. + upb_MiniTable_Sub* sub = const_cast( + &table->subs[table->fields[1].submsg_index]); + sub->submsg = nullptr; + sub = const_cast( + &table->subs[table->fields[2].submsg_index]); + sub->submsg = nullptr; + return table; +} + +TEST(GeneratedCode, PromoteUnknownMessage) { + upb_Arena* arena = upb_Arena_New(); + upb_test_ModelWithSubMessages* input_msg = + upb_test_ModelWithSubMessages_new(arena); + upb_test_ModelWithExtensions* sub_message = + upb_test_ModelWithExtensions_new(arena); + upb_test_ModelWithSubMessages_set_id(input_msg, 11); + upb_test_ModelWithExtensions_set_random_int32(sub_message, 12); + upb_test_ModelWithSubMessages_set_optional_child(input_msg, sub_message); + size_t serialized_size; + char* serialized = upb_test_ModelWithSubMessages_serialize(input_msg, arena, + &serialized_size); + + upb_MiniTable* mini_table = CreateMiniTableWithEmptySubTables(arena); + upb_Message* msg = _upb_Message_New(mini_table, arena); + upb_DecodeStatus decode_status = upb_Decode(serialized, serialized_size, msg, + mini_table, nullptr, 0, arena); + EXPECT_EQ(decode_status, kUpb_DecodeStatus_Ok); + int32_t val = upb_MiniTable_GetInt32( + msg, upb_MiniTable_FindFieldByNumber(mini_table, 4)); + EXPECT_EQ(val, 11); + upb_FindUnknownRet unknown = upb_MiniTable_FindUnknown(msg, 5); + EXPECT_EQ(unknown.status, kUpb_FindUnknown_Ok); + // Update mini table and promote unknown to a message. + upb_MiniTable_SetSubMessage(mini_table, + (upb_MiniTable_Field*)&mini_table->fields[1], + &upb_test_ModelWithExtensions_msg_init); + const int decode_options = + UPB_DECODE_MAXDEPTH(100); // UPB_DECODE_ALIAS disabled. + upb_UnknownToMessageRet promote_result = + upb_MiniTable_PromoteUnknownToMessage( + msg, mini_table, &mini_table->fields[1], + &upb_test_ModelWithExtensions_msg_init, decode_options, arena); + EXPECT_EQ(promote_result.status, kUpb_UnknownToMessage_Ok); + const upb_Message* promoted_message = + upb_MiniTable_GetMessage(msg, &mini_table->fields[1]); + EXPECT_EQ(upb_test_ModelWithExtensions_random_int32( + (upb_test_ModelWithExtensions*)promoted_message), + 12); + upb_Arena_Free(arena); +} + } // namespace diff --git a/upb/msg.c b/upb/msg.c index 53ea32217b..144b2125c5 100644 --- a/upb/msg.c +++ b/upb/msg.c @@ -112,6 +112,24 @@ const char* upb_Message_GetUnknown(const upb_Message* msg, size_t* len) { } } +void upb_Message_DeleteUnknown(upb_Message* msg, const char* data, size_t len) { + upb_Message_Internal* in = upb_Message_Getinternal(msg); + const char* internal_unknown_end = + UPB_PTR_AT(in->internal, in->internal->unknown_end, char); +#ifndef NDEBUG + size_t full_unknown_size; + const char* full_unknown = upb_Message_GetUnknown(msg, &full_unknown_size); + UPB_ASSERT((uintptr_t)data >= (uintptr_t)full_unknown); + UPB_ASSERT((uintptr_t)data < (uintptr_t)(full_unknown + full_unknown_size)); + UPB_ASSERT((uintptr_t)(data + len) > (uintptr_t)data); + UPB_ASSERT((uintptr_t)(data + len) <= (uintptr_t)internal_unknown_end); +#endif + if ((data + len) != internal_unknown_end) { + memmove((char*)data, data + len, internal_unknown_end - data - len); + } + in->internal->unknown_end -= len; +} + const upb_Message_Extension* _upb_Message_Getexts(const upb_Message* msg, size_t* count) { const upb_Message_Internal* in = upb_Message_Getinternal_const(msg); diff --git a/upb/msg.h b/upb/msg.h index cdb5c2bd13..166b877e81 100644 --- a/upb/msg.h +++ b/upb/msg.h @@ -59,6 +59,9 @@ void upb_Message_AddUnknown(upb_Message* msg, const char* data, size_t len, /* Returns a reference to the message's unknown data. */ const char* upb_Message_GetUnknown(const upb_Message* msg, size_t* len); +/* Removes partial unknown data from message. */ +void upb_Message_DeleteUnknown(upb_Message* msg, const char* data, size_t len); + /* Returns the number of extensions present in this message. */ size_t upb_Message_ExtensionCount(const upb_Message* msg); diff --git a/upb/test.proto b/upb/test.proto index d4e6b84488..ed9c1c0357 100644 --- a/upb/test.proto +++ b/upb/test.proto @@ -72,3 +72,9 @@ message ModelExtension2 { } optional int32 i = 9; } + +message ModelWithSubMessages { + optional int32 id = 4; + optional ModelWithExtensions optional_child = 5; + repeated ModelWithExtensions items = 6; +}