diff --git a/src/converter/pytorch_node.py b/src/converter/pytorch_node.py index 16c77c42..f4a67bdc 100644 --- a/src/converter/pytorch_node.py +++ b/src/converter/pytorch_node.py @@ -13,16 +13,14 @@ class PyTorchNodeType(Enum): class PyTorchNode: """ - Represents a node in a PyTorch execution trace, initialized based on a - schema version. + Represents a node in a PyTorch execution trace, initialized based on a schema version. Attributes: schema (str): Schema version used for initialization. data_deps (List[PyTorchNode]): List of data-dependent parent nodes. children (List[PyTorchNode]): List of child nodes. gpu_children (List[PyTorchNode]): List of GPU-specific child nodes. - record_param_comms_node (Optional[PyTorchNode]): Corresponding - record_param_comms node. + record_param_comms_node (Optional[PyTorchNode]): Corresponding record_param_comms node. nccl_node (Optional[PyTorchNode]): Corresponding NCCL node. id (str): Identifier of the node. name (str): Name of the node. @@ -45,10 +43,8 @@ def __init__(self, schema: str, node_data: Dict[str, Any]) -> None: provided. Args: - schema (str): The schema version based on which the node will be - initialized. - node_data (Dict[str, Any]): Dictionary containing the data of the - PyTorch node. + schema (str): The schema version based on which the node will be initialized. + node_data (Dict[str, Any]): Dictionary containing the data of the PyTorch node. """ self.schema = schema self.data_deps: List["PyTorchNode"] = [] @@ -67,12 +63,8 @@ def __repr__(self) -> str: str: String representation of the node. """ return ( - f"PyTorchNode(" - f"id={self.id}, name={self.name}, " - f"op_type={self.get_op_type()}, " - f"timestamp={self.ts}, " - f"inclusive_duration={self.inclusive_dur}, " - f"exclusive_duration={self.exclusive_dur})" + f"PyTorchNode(id={self.id}, name={self.name}, op_type={self.get_op_type()}, timestamp={self.ts}, " + f"inclusive_duration={self.inclusive_dur}, exclusive_duration={self.exclusive_dur})" ) def parse_data(self, node_data: Dict[str, Any]) -> None: @@ -87,9 +79,8 @@ def parse_data(self, node_data: Dict[str, Any]) -> None: self._parse_data_1_0_3_chakra_0_0_4(node_data) else: raise ValueError( - f"Unsupported schema version '{self.schema}'. Please check " - f"if the schema version is in the list of supported versions: " - f"{self.SUPPORTED_VERSIONS}" + f"Unsupported schema version '{self.schema}'. Please check if the schema version is in the list of " + f"supported versions: {self.SUPPORTED_VERSIONS}" ) def _parse_data_1_0_3_chakra_0_0_4(self, node_data: Dict[str, Any]) -> None: @@ -163,8 +154,7 @@ def add_gpu_child(self, gpu_child_node: "PyTorchNode") -> None: Adds a child GPU node for this node. Args: - gpu_child_node (Optional[PyTorchNode]): The child GPU node to be - added. + gpu_child_node (Optional[PyTorchNode]): The child GPU node to be added. """ self.gpu_children.append(gpu_child_node) @@ -173,8 +163,7 @@ def is_record_param_comms_op(self) -> bool: Checks if the node is a record_param_comms operator. Returns: - bool: True if the node is a record_param_comms operator, False - otherwise. + bool: True if the node is a record_param_comms operator, False otherwise. """ return "record_param_comms" in self.name @@ -246,10 +235,8 @@ def get_data_type_size(data_type: str) -> int: except KeyError as e: traceback_str = traceback.format_exc() raise ValueError( - f"Unsupported data type: {data_type}. The data_type_size_map " - f"dictionary is used for mapping the number of bytes for a " - f"given tensor data type. This dictionary may be incomplete. " - f"Please update the data_type_size_map or report this issue " - f"to the maintainer by creating an issue. Traceback:\n" + f"Unsupported data type: {data_type}. The data_type_size_map dictionary is used for mapping the " + f"number of bytes for a given tensor data type. This dictionary may be incomplete. Please update the " + f"data_type_size_map or report this issue to the maintainer by creating an issue. Traceback:\n" f"{traceback_str}" ) from e