Skip to content

Commit

Permalink
fix: additional logic to mitigate collisions with reserved terms (#301)
Browse files Browse the repository at this point in the history
* fix: additional logic to mitigate collisions with reserved terms

* address review feedback

* remove line break

* remove redundant code

* add assert
  • Loading branch information
parthea authored Feb 18, 2022
1 parent 06d0620 commit c9a77df
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 25 deletions.
90 changes: 67 additions & 23 deletions proto/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,34 +527,72 @@ def __init__(
# coerced.
marshal = self._meta.marshal
for key, value in mapping.items():
(key, pb_type) = self._get_pb_type_from_key(key)
if pb_type is None:
if ignore_unknown_fields:
continue

raise ValueError(
"Unknown field for {}: {}".format(self.__class__.__name__, key)
)

try:
pb_type = self._meta.fields[key].pb_type
except KeyError:
pb_value = marshal.to_proto(pb_type, value)
except ValueError:
# Underscores may be appended to field names
# that collide with python or proto-plus keywords.
# In case a key only exists with a `_` suffix, coerce the key
# to include the `_` suffix. Is not possible to
# to include the `_` suffix. It's not possible to
# natively define the same field with a trailing underscore in protobuf.
# See related issue
# https://github.com/googleapis/python-api-core/issues/227
if f"{key}_" in self._meta.fields:
key = f"{key}_"
pb_type = self._meta.fields[key].pb_type
else:
if ignore_unknown_fields:
continue

raise ValueError(
"Unknown field for {}: {}".format(self.__class__.__name__, key)
)

pb_value = marshal.to_proto(pb_type, value)
if isinstance(value, dict):
keys_to_update = [
item
for item in value
if not hasattr(pb_type, item) and hasattr(pb_type, f"{item}_")
]
for item in keys_to_update:
value[f"{item}_"] = value.pop(item)

pb_value = marshal.to_proto(pb_type, value)

if pb_value is not None:
params[key] = pb_value

# Create the internal protocol buffer.
super().__setattr__("_pb", self._meta.pb(**params))

def _get_pb_type_from_key(self, key):
"""Given a key, return the corresponding pb_type.
Args:
key(str): The name of the field.
Returns:
A tuple containing a key and pb_type. The pb_type will be
the composite type of the field, or the primitive type if a primitive.
If no corresponding field exists, return None.
"""

pb_type = None

try:
pb_type = self._meta.fields[key].pb_type
except KeyError:
# Underscores may be appended to field names
# that collide with python or proto-plus keywords.
# In case a key only exists with a `_` suffix, coerce the key
# to include the `_` suffix. It's not possible to
# natively define the same field with a trailing underscore in protobuf.
# See related issue
# https://github.com/googleapis/python-api-core/issues/227
if f"{key}_" in self._meta.fields:
key = f"{key}_"
pb_type = self._meta.fields[key].pb_type

return (key, pb_type)

def __dir__(self):
desc = type(self).pb().DESCRIPTOR
names = {f_name for f_name in self._meta.fields.keys()}
Expand Down Expand Up @@ -664,13 +702,14 @@ def __getattr__(self, key):
their Python equivalents. See the ``marshal`` module for
more details.
"""
try:
pb_type = self._meta.fields[key].pb_type
pb_value = getattr(self._pb, key)
marshal = self._meta.marshal
return marshal.to_python(pb_type, pb_value, absent=key not in self)
except KeyError as ex:
raise AttributeError(str(ex))
(key, pb_type) = self._get_pb_type_from_key(key)
if pb_type is None:
raise AttributeError(
"Unknown field for {}: {}".format(self.__class__.__name__, key)
)
pb_value = getattr(self._pb, key)
marshal = self._meta.marshal
return marshal.to_python(pb_type, pb_value, absent=key not in self)

def __ne__(self, other):
"""Return True if the messages are unequal, False otherwise."""
Expand All @@ -688,7 +727,12 @@ def __setattr__(self, key, value):
if key[0] == "_":
return super().__setattr__(key, value)
marshal = self._meta.marshal
pb_type = self._meta.fields[key].pb_type
(key, pb_type) = self._get_pb_type_from_key(key)
if pb_type is None:
raise AttributeError(
"Unknown field for {}: {}".format(self.__class__.__name__, key)
)

pb_value = marshal.to_proto(pb_type, value)

# Clear the existing field.
Expand Down
49 changes: 47 additions & 2 deletions tests/test_fields_mitigate_collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# limitations under the License.

import proto

import pytest

# Underscores may be appended to field names
# that collide with python or proto-plus keywords.
# In case a key only exists with a `_` suffix, coerce the key
# to include the `_` suffix. Is not possible to
# to include the `_` suffix. It's not possible to
# natively define the same field with a trailing underscore in protobuf.
# See related issue
# https://github.com/googleapis/python-api-core/issues/227
Expand All @@ -27,10 +27,55 @@ class TestMessage(proto.Message):
spam_ = proto.Field(proto.STRING, number=1)
eggs = proto.Field(proto.STRING, number=2)

class TextStream(proto.Message):
text_stream = proto.Field(TestMessage, number=1)

obj = TestMessage(spam_="has_spam")
obj.eggs = "has_eggs"
assert obj.spam_ == "has_spam"

# Test that `spam` is coerced to `spam_`
modified_obj = TestMessage({"spam": "has_spam", "eggs": "has_eggs"})
assert modified_obj.spam_ == "has_spam"

# Test get and set
modified_obj.spam = "no_spam"
assert modified_obj.spam == "no_spam"

modified_obj.spam_ = "yes_spam"
assert modified_obj.spam_ == "yes_spam"

modified_obj.spam = "maybe_spam"
assert modified_obj.spam_ == "maybe_spam"

modified_obj.spam_ = "maybe_not_spam"
assert modified_obj.spam == "maybe_not_spam"

# Try nested values
modified_obj = TextStream(
text_stream=TestMessage({"spam": "has_spam", "eggs": "has_eggs"})
)
assert modified_obj.text_stream.spam_ == "has_spam"

# Test get and set for nested values
modified_obj.text_stream.spam = "no_spam"
assert modified_obj.text_stream.spam == "no_spam"

modified_obj.text_stream.spam_ = "yes_spam"
assert modified_obj.text_stream.spam_ == "yes_spam"

modified_obj.text_stream.spam = "maybe_spam"
assert modified_obj.text_stream.spam_ == "maybe_spam"

modified_obj.text_stream.spam_ = "maybe_not_spam"
assert modified_obj.text_stream.spam == "maybe_not_spam"

with pytest.raises(AttributeError):
assert modified_obj.text_stream.attribute_does_not_exist == "n/a"

with pytest.raises(AttributeError):
modified_obj.text_stream.attribute_does_not_exist = "n/a"

# Try using dict
modified_obj = TextStream(text_stream={"spam": "has_spam", "eggs": "has_eggs"})
assert modified_obj.text_stream.spam_ == "has_spam"

0 comments on commit c9a77df

Please sign in to comment.