diff --git a/data_structures/binary_tree/binary_tree_node_sum.py b/data_structures/binary_tree/binary_tree_node_sum.py index be82e7f580b6..5a13e74e3c9f 100644 --- a/data_structures/binary_tree/binary_tree_node_sum.py +++ b/data_structures/binary_tree/binary_tree_node_sum.py @@ -11,6 +11,8 @@ from __future__ import annotations +from collections.abc import Iterator + class Node: """ @@ -33,24 +35,24 @@ class BinaryTreeNodeSum: 12 8 0 >>> tree = Node(10) - >>> BinaryTreeNodeSum(tree).node_sum() + >>> sum(BinaryTreeNodeSum(tree)) 10 >>> tree.left = Node(5) - >>> BinaryTreeNodeSum(tree).node_sum() + >>> sum(BinaryTreeNodeSum(tree)) 15 >>> tree.right = Node(-3) - >>> BinaryTreeNodeSum(tree).node_sum() + >>> sum(BinaryTreeNodeSum(tree)) 12 >>> tree.left.left = Node(12) - >>> BinaryTreeNodeSum(tree).node_sum() + >>> sum(BinaryTreeNodeSum(tree)) 24 >>> tree.right.left = Node(8) >>> tree.right.right = Node(0) - >>> BinaryTreeNodeSum(tree).node_sum() + >>> sum(BinaryTreeNodeSum(tree)) 32 """ @@ -64,8 +66,8 @@ def depth_first_search(self, node: Node | None) -> int: self.depth_first_search(node.left) + self.depth_first_search(node.right) ) - def node_sum(self) -> int: - return self.depth_first_search(self.tree) + def __iter__(self) -> Iterator[int]: + yield self.depth_first_search(self.tree) if __name__ == "__main__":