Skip to content

Commit

Permalink
[Python Schema] Fix redefined Record or Enum in Python schema (apache…
Browse files Browse the repository at this point in the history
…#11595)

Fixes apache#11533

### Motivation

Refer to issue apache#11533 , currently, if users redefined the same `Record` or `Enum` in `Record`, the schema info isn't reused the defined name, this does not match the Avro schema info format.

### Modifications

Add a new method `schema_info(self, defined_names)` in `Record`, `Array`, `Map`, and `Enum`, all defined names will be added in the parameter `defined_names` when users use a defined `Record`, or `Enum`, the schema info will use the name of the defined `Record` or `Enum` as the type.
  • Loading branch information
gaoran10 authored and ciaocloud committed Oct 16, 2021
1 parent fa36e9b commit 84a697a
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 15 deletions.
40 changes: 36 additions & 4 deletions pulsar-client-cpp/python/pulsar/schema/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,23 @@ def __init__(self, default=None, required_default=False, required=False, *args,

@classmethod
def schema(cls):
return cls.schema_info(set())

@classmethod
def schema_info(cls, defined_names):
if cls.__name__ in defined_names:
return cls.__name__

defined_names.add(cls.__name__)
schema = {
'name': str(cls.__name__),
'type': 'record',
'fields': []
}

for name in sorted(cls._fields.keys()):
field = cls._fields[name]
field_type = field.schema() if field._required else ['null', field.schema()]
field_type = field.schema_info(defined_names) \
if field._required else ['null', field.schema_info(defined_names)]
schema['fields'].append({
'name': name,
'type': field_type,
Expand Down Expand Up @@ -198,6 +206,9 @@ def schema(self):
# For primitive types, the schema would just be the type itself
return self.type()

def schema_info(self, defined_names):
return self.type()

def default(self):
return self._default

Expand Down Expand Up @@ -347,6 +358,9 @@ def python_type(self):
return self.enum_type

def validate_type(self, name, val):
if val is None:
return None

if type(val) is str:
# The enum was passed as a string, we need to check it against the possible values
if val in self.enum_type.__members__:
Expand All @@ -367,6 +381,12 @@ def validate_type(self, name, val):
return val

def schema(self):
return self.schema_info(set())

def schema_info(self, defined_names):
if self.enum_type.__name__ in defined_names:
return self.enum_type.__name__
defined_names.add(self.enum_type.__name__)
return {
'type': self.type(),
'name': self.enum_type.__name__,
Expand All @@ -393,6 +413,9 @@ def python_type(self):
return list

def validate_type(self, name, val):
if val is None:
return None

super(Array, self).validate_type(name, val)

for x in val:
Expand All @@ -402,9 +425,12 @@ def validate_type(self, name, val):
return val

def schema(self):
return self.schema_info(set())

def schema_info(self, defined_names):
return {
'type': self.type(),
'items': self.array_type.schema() if isinstance(self.array_type, (Array, Map, Record))
'items': self.array_type.schema_info(defined_names) if isinstance(self.array_type, (Array, Map, Record))
else self.array_type.type()
}

Expand All @@ -428,6 +454,9 @@ def python_type(self):
return dict

def validate_type(self, name, val):
if val is None:
return None

super(Map, self).validate_type(name, val)

for k, v in val.items():
Expand All @@ -440,9 +469,12 @@ def validate_type(self, name, val):
return val

def schema(self):
return self.schema_info(set())

def schema_info(self, defined_names):
return {
'type': self.type(),
'values': self.value_type.schema() if isinstance(self.value_type, (Array, Map, Record))
'values': self.value_type.schema_info(defined_names) if isinstance(self.value_type, (Array, Map, Record))
else self.value_type.type()
}

Expand Down
58 changes: 47 additions & 11 deletions pulsar-client-cpp/python/schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#

from unittest import TestCase, main

import fastavro
import pulsar
from pulsar.schema import *
from enum import Enum
Expand Down Expand Up @@ -46,6 +48,7 @@ class Example(Record):
h = Bytes()
i = Map(String())

fastavro.parse_schema(Example.schema())
self.assertEqual(Example.schema(), {
"name": "Example",
"type": "record",
Expand Down Expand Up @@ -84,6 +87,7 @@ class Example(Record):
sub = MySubRecord # Test with class
sub2 = MySubRecord() # Test with instance

fastavro.parse_schema(Example.schema())
self.assertEqual(Example.schema(), {
"name": "Example",
"type": "record",
Expand All @@ -99,13 +103,7 @@ class Example(Record):
}]
},
{"name": "sub2",
"type": ["null", {
"name": "MySubRecord",
"type": "record",
"fields": [{"name": "x", "type": ["null", "int"]},
{"name": "y", "type": ["null", "long"]},
{"name": "z", "type": ["null", "string"]}]
}]
"type": ["null", 'MySubRecord']
}
]
})
Expand Down Expand Up @@ -896,39 +894,55 @@ class NestedObj4(Record):
na4 = String()
nb4 = Integer()

class Color(Enum):
red = 1
green = 2
blue = 3

class ComplexRecord(Record):
a = Integer()
b = Integer()
color = Color
color2 = Color
nested = NestedObj2()
nested2 = NestedObj2()
mapNested = Map(NestedObj3())
mapNested2 = Map(NestedObj3())
arrayNested = Array(NestedObj4())
arrayNested2 = Array(NestedObj4())

print('complex schema: ', ComplexRecord.schema())
self.assertEqual(ComplexRecord.schema(), {
"name": "ComplexRecord",
"type": "record",
"fields": [
{"name": "a", "type": ["null", "int"]},
{'name': 'arrayNested', 'type': ['null',
{'type': 'array', 'items': {'name': 'NestedObj4', 'type': 'record', 'fields': [
{'name': 'arrayNested', 'type': ['null', {'type': 'array', 'items':
{'name': 'NestedObj4', 'type': 'record', 'fields': [
{'name': 'na4', 'type': ['null', 'string']},
{'name': 'nb4', 'type': ['null', 'int']}
]}}
]},
{'name': 'arrayNested2', 'type': ['null', {'type': 'array', 'items': 'NestedObj4'}]},
{"name": "b", "type": ["null", "int"]},
{'name': 'color', 'type': ['null', {'type': 'enum', 'name': 'Color', 'symbols': [
'red', 'green', 'blue']}]},
{'name': 'color2', 'type': ['null', 'Color']},
{'name': 'mapNested', 'type': ['null', {'type': 'map', 'values':
{'name': 'NestedObj3', 'type': 'record', 'fields': [
{'name': 'na3', 'type': ['null', 'int']}
]}}
]},
{'name': 'mapNested2', 'type': ['null', {'type': 'map', 'values': 'NestedObj3'}]},
{"name": "nested", "type": ['null', {'name': 'NestedObj2', 'type': 'record', 'fields': [
{'name': 'na2', 'type': ['null', 'int']},
{'name': 'nb2', 'type': ['null', 'boolean']},
{'name': 'nc2', 'type': ['null', {'name': 'NestedObj1', 'type': 'record', 'fields': [
{'name': 'na1', 'type': ['null', 'string']},
{'name': 'nb1', 'type': ['null', 'double']}
]}]}
]}]}
]}]},
{"name": "nested2", "type": ['null', 'NestedObj2']}
]
})

Expand All @@ -939,13 +953,22 @@ def encode_and_decode(schema_type):

nested_obj1 = NestedObj1(na1='na1 value', nb1=20.5)
nested_obj2 = NestedObj2(na2=22, nb2=True, nc2=nested_obj1)
r = ComplexRecord(a=1, b=2, nested=nested_obj2, mapNested={
r = ComplexRecord(a=1, b=2, color=Color.red, color2=Color.blue,
nested=nested_obj2, nested2=nested_obj2,
mapNested={
'a': NestedObj3(na3=1),
'b': NestedObj3(na3=2),
'c': NestedObj3(na3=3)
}, mapNested2={
'd': NestedObj3(na3=4),
'e': NestedObj3(na3=5),
'f': NestedObj3(na3=6)
}, arrayNested=[
NestedObj4(na4='value na4 1', nb4=100),
NestedObj4(na4='value na4 2', nb4=200)
], arrayNested2=[
NestedObj4(na4='value na4 3', nb4=300),
NestedObj4(na4='value na4 4', nb4=400)
])
data_encode = data_schema.encode(r)

Expand All @@ -954,17 +977,30 @@ def encode_and_decode(schema_type):
self.assertEqual(data_decode, r)
self.assertEqual(data_decode.a, 1)
self.assertEqual(data_decode.b, 2)
self.assertEqual(data_decode.color, Color.red)
self.assertEqual(data_decode.color2, Color.blue)
self.assertEqual(data_decode.nested.na2, 22)
self.assertEqual(data_decode.nested.nb2, True)
self.assertEqual(data_decode.nested.nc2.na1, 'na1 value')
self.assertEqual(data_decode.nested.nc2.nb1, 20.5)
self.assertEqual(data_decode.nested2.na2, 22)
self.assertEqual(data_decode.nested2.nb2, True)
self.assertEqual(data_decode.nested2.nc2.na1, 'na1 value')
self.assertEqual(data_decode.nested2.nc2.nb1, 20.5)
self.assertEqual(data_decode.mapNested['a'].na3, 1)
self.assertEqual(data_decode.mapNested['b'].na3, 2)
self.assertEqual(data_decode.mapNested['c'].na3, 3)
self.assertEqual(data_decode.mapNested2['d'].na3, 4)
self.assertEqual(data_decode.mapNested2['e'].na3, 5)
self.assertEqual(data_decode.mapNested2['f'].na3, 6)
self.assertEqual(data_decode.arrayNested[0].na4, 'value na4 1')
self.assertEqual(data_decode.arrayNested[0].nb4, 100)
self.assertEqual(data_decode.arrayNested[1].na4, 'value na4 2')
self.assertEqual(data_decode.arrayNested[1].nb4, 200)
self.assertEqual(data_decode.arrayNested2[0].na4, 'value na4 3')
self.assertEqual(data_decode.arrayNested2[0].nb4, 300)
self.assertEqual(data_decode.arrayNested2[1].na4, 'value na4 4')
self.assertEqual(data_decode.arrayNested2[1].nb4, 400)
print('Encode and decode complex schema finish. schema_type: ', schema_type)

encode_and_decode('avro')
Expand Down

0 comments on commit 84a697a

Please sign in to comment.