Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update docstring to comply with coding style and max column length #84

Merged
merged 1 commit into from
Jun 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 13 additions & 26 deletions src/converter/pytorch_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,14 @@ class PyTorchNodeType(Enum):

class PyTorchNode:
"""
Represents a node in a PyTorch execution trace, initialized based on a
schema version.
Represents a node in a PyTorch execution trace, initialized based on a schema version.

Attributes:
schema (str): Schema version used for initialization.
data_deps (List[PyTorchNode]): List of data-dependent parent nodes.
children (List[PyTorchNode]): List of child nodes.
gpu_children (List[PyTorchNode]): List of GPU-specific child nodes.
record_param_comms_node (Optional[PyTorchNode]): Corresponding
record_param_comms node.
record_param_comms_node (Optional[PyTorchNode]): Corresponding record_param_comms node.
nccl_node (Optional[PyTorchNode]): Corresponding NCCL node.
id (str): Identifier of the node.
name (str): Name of the node.
Expand All @@ -45,10 +43,8 @@ def __init__(self, schema: str, node_data: Dict[str, Any]) -> None:
provided.

Args:
schema (str): The schema version based on which the node will be
initialized.
node_data (Dict[str, Any]): Dictionary containing the data of the
PyTorch node.
schema (str): The schema version based on which the node will be initialized.
node_data (Dict[str, Any]): Dictionary containing the data of the PyTorch node.
"""
self.schema = schema
self.data_deps: List["PyTorchNode"] = []
Expand All @@ -67,12 +63,8 @@ def __repr__(self) -> str:
str: String representation of the node.
"""
return (
f"PyTorchNode("
f"id={self.id}, name={self.name}, "
f"op_type={self.get_op_type()}, "
f"timestamp={self.ts}, "
f"inclusive_duration={self.inclusive_dur}, "
f"exclusive_duration={self.exclusive_dur})"
f"PyTorchNode(id={self.id}, name={self.name}, op_type={self.get_op_type()}, timestamp={self.ts}, "
f"inclusive_duration={self.inclusive_dur}, exclusive_duration={self.exclusive_dur})"
)

def parse_data(self, node_data: Dict[str, Any]) -> None:
Expand All @@ -87,9 +79,8 @@ def parse_data(self, node_data: Dict[str, Any]) -> None:
self._parse_data_1_0_3_chakra_0_0_4(node_data)
else:
raise ValueError(
f"Unsupported schema version '{self.schema}'. Please check "
f"if the schema version is in the list of supported versions: "
f"{self.SUPPORTED_VERSIONS}"
f"Unsupported schema version '{self.schema}'. Please check if the schema version is in the list of "
f"supported versions: {self.SUPPORTED_VERSIONS}"
)

def _parse_data_1_0_3_chakra_0_0_4(self, node_data: Dict[str, Any]) -> None:
Expand Down Expand Up @@ -163,8 +154,7 @@ def add_gpu_child(self, gpu_child_node: "PyTorchNode") -> None:
Adds a child GPU node for this node.

Args:
gpu_child_node (Optional[PyTorchNode]): The child GPU node to be
added.
gpu_child_node (Optional[PyTorchNode]): The child GPU node to be added.
"""
self.gpu_children.append(gpu_child_node)

Expand All @@ -173,8 +163,7 @@ def is_record_param_comms_op(self) -> bool:
Checks if the node is a record_param_comms operator.

Returns:
bool: True if the node is a record_param_comms operator, False
otherwise.
bool: True if the node is a record_param_comms operator, False otherwise.
"""
return "record_param_comms" in self.name

Expand Down Expand Up @@ -246,10 +235,8 @@ def get_data_type_size(data_type: str) -> int:
except KeyError as e:
traceback_str = traceback.format_exc()
raise ValueError(
f"Unsupported data type: {data_type}. The data_type_size_map "
f"dictionary is used for mapping the number of bytes for a "
f"given tensor data type. This dictionary may be incomplete. "
f"Please update the data_type_size_map or report this issue "
f"to the maintainer by creating an issue. Traceback:\n"
f"Unsupported data type: {data_type}. The data_type_size_map dictionary is used for mapping the "
f"number of bytes for a given tensor data type. This dictionary may be incomplete. Please update the "
f"data_type_size_map or report this issue to the maintainer by creating an issue. Traceback:\n"
f"{traceback_str}"
) from e
Loading