diff --git a/stone/backends/python_client.py b/stone/backends/python_client.py index ff49601a..1b4c9ea6 100644 --- a/stone/backends/python_client.py +++ b/stone/backends/python_client.py @@ -108,7 +108,6 @@ class PythonClientBackend(CodeBackend): - # pylint: disable=attribute-defined-outside-init cmdline_parser = _cmdline_parser supported_auth_types = None diff --git a/stone/backends/python_types.py b/stone/backends/python_types.py index 83db416a..c035d824 100644 --- a/stone/backends/python_types.py +++ b/stone/backends/python_types.py @@ -14,10 +14,9 @@ from stone.ir import AnnotationType, ApiNamespace from stone.ir import ( - get_custom_annotations_for_alias, - get_custom_annotations_recursive, is_alias, is_boolean_type, + is_composite_type, is_bytes_type, is_list_type, is_map_type, @@ -642,7 +641,7 @@ def _generate_custom_annotation_processors(self, ns, data_type, extra_annotation dt, _, _ = unwrap(data_type) if is_struct_type(dt) or is_union_type(dt): annotation_types_seen = set() - for annotation in get_custom_annotations_recursive(dt): + for _, annotation in dt.recursive_custom_annotations: if annotation.annotation_type not in annotation_types_seen: yield (annotation.annotation_type, generate_func_call( @@ -672,7 +671,12 @@ def _generate_custom_annotation_processors(self, ns, data_type, extra_annotation # annotations applied directly to this type (through aliases or # passed in from the caller) - for annotation in itertools.chain(get_custom_annotations_for_alias(data_type), + indirect_annotations = dt.recursive_custom_annotations if is_composite_type(dt) else set() + all_annotations = (data_type.recursive_custom_annotations + if is_composite_type(data_type) else set()) + remaining_annotations = [annotation for _, annotation in + all_annotations.difference(indirect_annotations)] + for annotation in itertools.chain(remaining_annotations, extra_annotations): yield (annotation.annotation_type, generate_func_call( diff --git a/stone/frontend/ir_generator.py b/stone/frontend/ir_generator.py index 25503788..74c8d967 100644 --- a/stone/frontend/ir_generator.py +++ b/stone/frontend/ir_generator.py @@ -37,6 +37,7 @@ Int32, Int64, is_alias, + is_composite_type, is_field_type, is_list_type, is_map_type, @@ -297,6 +298,7 @@ def generate_IR(self): self._populate_field_defaults() self._populate_enumerated_subtypes() self._populate_route_attributes() + self._populate_recursive_custom_annotations() self._populate_examples() self._validate_doc_refs() self._validate_annotations() @@ -802,6 +804,75 @@ def _populate_union_type_attributes(self, env, data_type): data_type.set_attributes( data_type._ast_node.doc, api_type_fields, parent_type, catch_all_field) + def _populate_recursive_custom_annotations(self): + """ + Populates custom annotations applied to fields recursively. This is done in + a separate pass because it requires all fields and routes to be defined so that + recursive chains can be followed accurately. + """ + data_types_seen = set() + + def recurse(data_type): + # primitive types do not have annotations + if not is_composite_type(data_type): + return set() + + # if we have already analyzed data type, just return result + if data_type.recursive_custom_annotations is not None: + return data_type.recursive_custom_annotations + + # handle cycles safely (annotations will be found first time at top level) + if data_type in data_types_seen: + return set() + data_types_seen.add(data_type) + + annotations = set() + + # collect data types from subtypes recursively + if is_struct_type(data_type) or is_union_type(data_type): + for field in data_type.fields: + annotations.update(recurse(field.data_type)) + # annotations can be defined directly on fields + annotations.update([(field, annotation) + for annotation in field.custom_annotations]) + elif is_alias(data_type): + annotations.update(recurse(data_type.data_type)) + # annotations can be defined directly on aliases + annotations.update([(data_type, annotation) + for annotation in data_type.custom_annotations]) + elif is_list_type(data_type): + annotations.update(recurse(data_type.data_type)) + elif is_map_type(data_type): + # only map values support annotations for now + annotations.update(recurse(data_type.value_data_type)) + elif is_nullable_type(data_type): + annotations.update(recurse(data_type.data_type)) + + data_type.recursive_custom_annotations = annotations + return annotations + + for namespace in self.api.namespaces.values(): + namespace_annotations = set() + for data_type in namespace.data_types: + namespace_annotations.update(recurse(data_type)) + + for alias in namespace.aliases: + namespace_annotations.update(recurse(alias)) + + for route in namespace.routes: + namespace_annotations.update(recurse(route.arg_data_type)) + namespace_annotations.update(recurse(route.result_data_type)) + namespace_annotations.update(recurse(route.error_data_type)) + + # record annotation types as dependencies of the namespace. this allows for + # an optimization when processing custom annotations to ignore annotation + # types that are not applied to the data type, rather than recursing into it + for _, annotation in namespace_annotations: + if annotation.annotation_type.namespace.name != namespace.name: + namespace.add_imported_namespace( + annotation.annotation_type.namespace, + imported_annotation_type=True) + def _populate_field_defaults(self): """ Populate the defaults of each field. This is done in a separate pass diff --git a/stone/ir/data_types.py b/stone/ir/data_types.py index e4dd5c6f..0bd98dea 100644 --- a/stone/ir/data_types.py +++ b/stone/ir/data_types.py @@ -55,6 +55,30 @@ def generic_type_name(v): return type(v).__name__ +def record_custom_annotation_imports(annotation, namespace): + """ + Records imports for custom annotations in the given namespace. + + """ + # first, check the annotation *type* + if annotation.annotation_type.namespace.name != namespace.name: + namespace.add_imported_namespace( + annotation.annotation_type.namespace, + imported_annotation_type=True) + + # second, check if we need to import the annotation itself + + # the annotation namespace is currently not actually used in the + # backends, which reconstruct the annotation from the annotation + # type directly. This could be changed in the future, and at + # the IR level it makes sense to include the dependency + + if annotation.namespace.name != namespace.name: + namespace.add_imported_namespace( + annotation.namespace, + imported_annotation=True) + + class DataType(object): """ Abstract class representing a data type. @@ -118,6 +142,12 @@ class Composite(DataType): # pylint: disable=abstract-method Composite types are any data type which can be constructed using primitive data types and other composite types. """ + def __init__(self): + super(Composite, self).__init__() + # contains custom annotations that apply to any containing data types (recursively) + # format is (location, CustomAnnotation) to indicate a custom annotation is applied + # to a location (Field or Alias) + self.recursive_custom_annotations = None class Nullable(Composite): @@ -781,22 +811,7 @@ def set_attributes(self, doc, fields, parent_type=None): # they are treated as globals at the IR level for field in self.fields: for annotation in field.custom_annotations: - # first, check the annotation *type* - if annotation.annotation_type.namespace.name != self.namespace.name: - self.namespace.add_imported_namespace( - annotation.annotation_type.namespace, - imported_annotation_type=True) - - # second, check if we need to import the annotation itself - - # the annotation namespace is currently not actually used in the - # backends, which reconstruct the annotation from the annotation - # type directly. This could be changed in the future, and at - # the IR level it makes sense to include the dependency - if annotation.namespace.name != self.namespace.name: - self.namespace.add_imported_namespace( - annotation.namespace, - imported_annotation=True) + record_custom_annotation_imports(annotation, self.namespace) # Indicate that the attributes of the type have been populated. self._is_forward_ref = False @@ -901,7 +916,6 @@ class Struct(UserDefined): """ Defines a product type: Composed of other primitive and/or struct types. """ - # pylint: disable=attribute-defined-outside-init composite_type = 'struct' @@ -1359,7 +1373,6 @@ def __repr__(self): class Union(UserDefined): """Defines a tagged union. Fields are variants.""" - # pylint: disable=attribute-defined-outside-init composite_type = 'union' @@ -1830,25 +1843,7 @@ def set_annotations(self, annotations): elif isinstance(annotation, CustomAnnotation): # Note: we don't need to do this for builtin annotations because # they are treated as globals at the IR level - - # first, check the annotation *type* - if annotation.annotation_type.namespace.name != self.namespace.name: - self.namespace.add_imported_namespace( - annotation.annotation_type.namespace, - imported_annotation_type=True) - - # second, check if we need to import the annotation itself - - # the annotation namespace is currently not actually used in the - # backends, which reconstruct the annotation from the annotation - # type directly. This could be changed in the future, and at - # the IR level it makes sense to include the dependency - - if annotation.namespace.name != self.namespace.name: - self.namespace.add_imported_namespace( - annotation.namespace, - imported_annotation=True) - + record_custom_annotation_imports(annotation, self.namespace) self.custom_annotations.append(annotation) else: raise InvalidSpec("Aliases only support 'Redacted' and custom annotations, not %r" % @@ -2002,53 +1997,6 @@ def unwrap(data_type): data_type = data_type.data_type return data_type, unwrapped_nullable, unwrapped_alias -def get_custom_annotations_for_alias(data_type): - """ - Given a Stone data type, returns all custom annotations applied to it. - """ - # annotations can only be applied to Aliases, but they can be wrapped in - # Nullable. also, Aliases pointing to other Aliases don't automatically - # inherit their custom annotations, so we might have to traverse. - result = [] - data_type, _ = unwrap_nullable(data_type) - while is_alias(data_type): - result.extend(data_type.custom_annotations) - data_type, _ = unwrap_nullable(data_type.data_type) - return result - -def get_custom_annotations_recursive(data_type): - """ - Given a Stone data type, returns all custom annotations applied to any of - its memebers, as well as submembers, ..., to an arbitrary depth. - """ - # because Stone structs can contain references to themselves (or otherwise - # be cyclical), we need ot keep track of the data types we've already seen - data_types_seen = set() - - def recurse(data_type): - if data_type in data_types_seen: - return - data_types_seen.add(data_type) - - dt, _, _ = unwrap(data_type) - if is_struct_type(dt) or is_union_type(dt): - for field in dt.fields: - for annotation in recurse(field.data_type): - yield annotation - for annotation in field.custom_annotations: - yield annotation - elif is_list_type(dt): - for annotation in recurse(dt.data_type): - yield annotation - elif is_map_type(dt): - for annotation in recurse(dt.value_data_type): - yield annotation - - for annotation in get_custom_annotations_for_alias(data_type): - yield annotation - - return recurse(data_type) - def is_alias(data_type): return isinstance(data_type, Alias) diff --git a/test/test_python_gen.py b/test/test_python_gen.py index 1562e73b..a93162b2 100755 --- a/test/test_python_gen.py +++ b/test/test_python_gen.py @@ -229,7 +229,6 @@ def test_json_encoder(self): self.assertEqual(json_encode(bv.Nullable(bv.String()), u'abc'), json.dumps('abc')) def test_json_encoder_union(self): - # pylint: disable=attribute-defined-outside-init class S(object): _all_field_names_ = {'f'} _all_fields_ = [('f', bv.String())] @@ -331,7 +330,6 @@ def _get_val_data_type(cls, tag, cp): self.assertEqual(json_encode(bv.Union(U), u, old_style=True), json.dumps({'g': m})) def test_json_encoder_error_messages(self): - # pylint: disable=attribute-defined-outside-init class S3(object): _all_field_names_ = {'j'} _all_fields_ = [('j', bv.UInt64(max_value=10))] diff --git a/test/test_python_types.py b/test/test_python_types.py index 89800e61..c5972683 100644 --- a/test/test_python_types.py +++ b/test/test_python_types.py @@ -171,6 +171,7 @@ def test_struct_with_custom_annotations(self): StructField('unannotated_field', Int32(), None, None), ]) struct.fields[0].set_annotations([annotation]) + struct.recursive_custom_annotations = set([annotation]) result = self._evaluate_struct(ns, struct) diff --git a/test/test_stone.py b/test/test_stone.py index 61b3b79e..b3592d58 100755 --- a/test/test_stone.py +++ b/test/test_stone.py @@ -4886,6 +4886,52 @@ def test_custom_annotations(self): struct = api.namespaces['test'].data_type_by_name['TestStruct'] self.assertEqual(struct.fields[0].custom_annotations[0], annotation) + self.assertEqual(struct.recursive_custom_annotations, set([ + (alias, api.namespaces['test'].annotation_by_name['VeryImportant']), + (struct.fields[0], api.namespaces['test'].annotation_by_name['SortaImportant']), + ])) + + # Test recursive references are captured + ns2 = textwrap.dedent("""\ + namespace testchain + + import test + + alias TestAliasChain = String + @test.SortaImportant + + struct TestStructChain + f test.TestStruct + g List(TestAliasChain) + """) + ns3 = textwrap.dedent("""\ + namespace teststruct + + import testchain + + struct TestStructToStruct + f testchain.TestStructChain + """) + ns4 = textwrap.dedent("""\ + namespace testalias + + import testchain + + struct TestStructToAlias + f testchain.TestAliasChain + """) + + api = specs_to_ir([('test.stone', text), ('testchain.stone', ns2), + ('teststruct.stone', ns3), ('testalias.stone', ns4)]) + + struct_namespaces = [ns.name for ns in + api.namespaces['teststruct'].get_imported_namespaces( + consider_annotation_types=True)] + self.assertTrue('test' in struct_namespaces) + alias_namespaces = [ns.name for ns in + api.namespaces['testalias'].get_imported_namespaces( + consider_annotation_types=True)] + self.assertTrue('test' in alias_namespaces) if __name__ == '__main__':