diff --git a/CHANGELOG.md b/CHANGELOG.md index 2a08a41e90..cc7883b161 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,8 @@ different versioning scheme, following the Haskell community's * gRPC v1.7.1 is now required to use Bond-over-gRPC. * Fixed includes for gRPC services with events or parameterless methods. [Issue #735](https://github.com/Microsoft/bond/issues/735) +* Fixed a bug which would read an unrelated struct's field(s) when deserializing a + base struct. [Issue #742](https://github.com/Microsoft/bond/issues/742) ### C# ### diff --git a/cpp/inc/bond/core/detail/inheritance.h b/cpp/inc/bond/core/detail/inheritance.h index a3c4deed3f..b7e401995e 100644 --- a/cpp/inc/bond/core/detail/inheritance.h +++ b/cpp/inc/bond/core/detail/inheritance.h @@ -101,10 +101,17 @@ class ParserInheritance // First we recurse into base structs (serialized data starts at the top of the hierarchy) // and then we read to the transform the fields of the top level struct. transform.Begin(T::metadata); - ReadBase(base_class(), transform); - bool result = static_cast(this)->ReadFields(typename boost::mpl::begin::type(), transform); + + bool done = ReadBase(base_class(), transform); + + if (!done) + { + done = static_cast(this)->ReadFields(typename boost::mpl::begin::type(), transform); + } + transform.End(); - return result; + + return done; } @@ -138,7 +145,7 @@ class ParserInheritance bool Read(const RuntimeSchema& schema, const Transform& transform) { // The logic is the same as for compile-time schemas, described in the comments above. - bool result; + bool done; typename base_input::type base(base_input::from(_input)); @@ -148,7 +155,7 @@ class ParserInheritance detail::StructBegin(_input, true); - result = Parser(base, _base).Read(schema.GetBaseSchema(), transform); + done = Parser(base, _base).Read(schema.GetBaseSchema(), transform); detail::StructEnd(_input, true); @@ -158,14 +165,17 @@ class ParserInheritance { transform.Begin(schema.GetStruct().metadata); - if (schema.HasBase()) - transform.Base(bonded(base, schema.GetBaseSchema(), true)); + done = schema.HasBase() && transform.Base(bonded(base, schema.GetBaseSchema(), true)); + + if (!done) + { + done = static_cast(this)->ReadFields(schema, transform); + } - result = static_cast(this)->ReadFields(schema, transform); transform.End(); } - return result; + return done; } Input _input; diff --git a/cpp/inc/bond/core/parser.h b/cpp/inc/bond/core/parser.h index 4f17db085b..3cfbb252e5 100644 --- a/cpp/inc/bond/core/parser.h +++ b/cpp/inc/bond/core/parser.h @@ -271,6 +271,8 @@ class DynamicParser ReadFields(fields, id, type, transform); + bool done; + if (!_base) { // If we are not parsing a base class, and we still didn't get to @@ -292,11 +294,17 @@ class DynamicParser else UnknownField(id, type, transform); } + + done = false; + } + else + { + done = (type == bond::BT_STOP); } _input.ReadFieldEnd(); - return false; + return done; } @@ -469,6 +477,8 @@ class DynamicParser UnknownField(id, type, transform); } + bool done; + if (!_base) { // If we are not parsing a base class, and we still didn't get to @@ -494,11 +504,17 @@ class DynamicParser UnknownField(id, type, transform); } } + + done = false; + } + else + { + done = (type == bond::BT_STOP); } _input.ReadFieldEnd(); - return false; + return done; } diff --git a/cpp/inc/bond/core/transforms.h b/cpp/inc/bond/core/transforms.h index 1ace8114ee..6b9bc21369 100644 --- a/cpp/inc/bond/core/transforms.h +++ b/cpp/inc/bond/core/transforms.h @@ -432,6 +432,9 @@ class RequiredFieldValiadator template void RequiredFieldValiadator::MissingFieldException() const { + // Force instantiation of template statics + (void)typename schema::type(); + BOND_THROW(CoreException, "De-serialization failed: required field " << _required << " is missing from " << schema::type::metadata.qualified_name); @@ -530,7 +533,12 @@ class To template bool Base(const X& value) const { - return AssignToBase(_var, value); + if (AssignToBase(_var, value)) + { + UnexpectedStructStopException(); + } + + return false; } @@ -593,6 +601,16 @@ class To } private: + BOND_NORETURN void UnexpectedStructStopException() const + { + // Force instantiation of template statics + (void)typename schema::type(); + + BOND_THROW(CoreException, + "De-serialization failed: unexpected struct stop encountered for " + << schema::type::metadata.qualified_name); + } + T& _var; }; diff --git a/cpp/test/core/inheritance_test.cpp b/cpp/test/core/inheritance_test.cpp index ab3af29758..a382e7a70b 100644 --- a/cpp/test/core/inheritance_test.cpp +++ b/cpp/test/core/inheritance_test.cpp @@ -10,6 +10,34 @@ bool Compare(const ListWithBase& left, const ListOfBase& right) && Equal(left.v4, right.v4); } +template +typename boost::enable_if >::type +DeserializeBaseToDerived() +{ + ListOfBondedBase obj; + GetBonded(InitRandom()).Deserialize(obj); + + for (const auto& base : obj.l) + { + bond::bonded derived_bonded(base); + StructWithBase derived; + BOOST_CHECK_THROW(derived_bonded.Deserialize(derived), bond::CoreException); + BOOST_CHECK_THROW(bond::bonded(derived_bonded).Deserialize(derived), bond::CoreException); + } +} + +template +typename boost::disable_if >::type +DeserializeBaseToDerived() +{} + +template +TEST_CASE_BEGIN(BaseToDerivedDeserializationTest) +{ + DeserializeBaseToDerived(); +} +TEST_CASE_END + template void InheritanceTests(const char* name) { @@ -50,6 +78,9 @@ void InheritanceTests(const char* name) // Deserialize as containers of base/partial hierarchy AddTestCase(suite, "Containers, partial hierarchy"); + + AddTestCase(suite, "Base to derived deserialization"); } diff --git a/cpp/test/core/unit_test.bond b/cpp/test/core/unit_test.bond index 4f8ea82a3d..5a52fea3b0 100644 --- a/cpp/test/core/unit_test.bond +++ b/cpp/test/core/unit_test.bond @@ -97,6 +97,12 @@ struct ListOfBase }; +struct ListOfBondedBase +{ + 1: list> l; +} + + struct NestedStruct1OptionalBondedView { 1: bonded s; diff --git a/python/test/core/unit_test.py b/python/test/core/unit_test.py index 5c9fe1c89a..4daa245ce8 100644 --- a/python/test/core/unit_test.py +++ b/python/test/core/unit_test.py @@ -389,11 +389,6 @@ def test_Bonded(self): src2.n2 = self.randomSimpleWithBase() dst2 = test.Bonded() Deserialize(Serialize(src), dst2) - # downcast bonded to bonded - bonded = test.bonded_unittest_SimpleWithBase_(dst2.n2) - obj3 = test.SimpleWithBase() - bonded.Deserialize(obj3) - self.assertTrue(obj3, src2.n2) def test_Polymorphism(self): src = test.Bonded()