From da5bcb8b6cc7a0db55efba0e98cf2a5272553f06 Mon Sep 17 00:00:00 2001
From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com>
Date: Wed, 10 Jan 2024 20:49:00 -0500
Subject: [PATCH] Refactor

---
 et_converter/pytorch2chakra_converter.py | 399 +++++++++++++----------
 et_converter/pytorch_tensor.py           |   1 +
 2 files changed, 232 insertions(+), 168 deletions(-)

diff --git a/et_converter/pytorch2chakra_converter.py b/et_converter/pytorch2chakra_converter.py
index ea86ad24..dbe2c32c 100644
--- a/et_converter/pytorch2chakra_converter.py
+++ b/et_converter/pytorch2chakra_converter.py
@@ -1,21 +1,13 @@
 #!/usr/bin/env python3
-
 import bisect
 import copy
 import json
 import logging
-
 from typing import Any, Dict, List, Optional, Tuple
 
 from chakra.third_party.utils.protolib import encodeMessage as encode_message
-from chakra.et_converter.pytorch_node import (
-    PyTorchNodeType,
-    PyTorchNode
-)
-from chakra.et_converter.pytorch_tensor import (
-    PyTorchTensor,
-    list_to_pytorch_tensor
-)
+from chakra.et_converter.pytorch_node import PyTorchNodeType, PyTorchNode
+from chakra.et_converter.pytorch_tensor import PyTorchTensor, list_to_pytorch_tensor
 from chakra.et_def.et_def_pb2 import (
     GlobalMetadata,
     Node as ChakraNode,
@@ -32,14 +24,22 @@ class UniqueIdAssigner:
+    """
+    Class for assigning unique IDs to original IDs.
+
+    Attributes:
+        next_id (int): The next available unique ID.
+        original_to_new_ids (Dict[int, int]): Mapping from original IDs to unique IDs.
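+
+    Example (illustrative; per the documented semantics of assign_unique_id,
+    a repeated call with the same original ID returns the previously
+    assigned unique ID):
+
+        assigner = UniqueIdAssigner()
+        first = assigner.assign_unique_id(100)
+        again = assigner.assign_unique_id(100)
+        assert first == again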
+    """
+
     def __init__(self) -> None:
         self.next_id = 0
         self.original_to_new_ids: Dict[int, int] = {}
 
     def assign_unique_id(self, original_id: int) -> int:
         """
-        Assigns a new unique ID to the given original ID if it doesn't have one already;
-        otherwise, returns the previously assigned unique ID.
+        Assigns a new unique ID to the given original ID if it doesn't have one
+        already; otherwise, returns the previously assigned unique ID.
 
         Args:
             original_id (int): The original ID for which a unique ID is needed.
@@ -55,6 +55,36 @@ def assign_unique_id(self, original_id: int) -> int:
 
 
 class PyTorch2ChakraConverter:
+    """
+    Converter class for transforming PyTorch execution traces into Chakra format.
+
+    This class is responsible for converting the execution traces collected
+    from PyTorch into a format that is compatible with Chakra, a performance
+    analysis tool. It handles the intricate mappings and transformations
+    required to accurately represent the execution in a different format.
+
+    Attributes:
+        input_filename (str): Input file name containing PyTorch execution trace.
+        output_filename (str): Output file name for the converted Chakra trace.
+        num_dims (int): Number of dimensions involved in the conversion process.
+        logger (logging.Logger): Logger for logging information during conversion.
+        id_assigner (UniqueIdAssigner): Object to manage unique ID assignments.
+        pytorch_schema (Optional[str]): Schema info of the PyTorch trace.
+        pytorch_pid (Optional[int]): Process ID associated with the PyTorch trace.
+        pytorch_time (Optional[str]): Time info of the PyTorch trace.
+        pytorch_start_ts (Optional[int]): Start timestamp of the PyTorch trace.
+        pytorch_finish_ts (Optional[int]): Finish timestamp of the PyTorch trace.
+        pytorch_nodes (Dict[int, Any]): Map of PyTorch node IDs to nodes.
+        pytorch_root_nids (List[int]): List of root node IDs in the PyTorch trace.
+        pytorch_cpu_node_id_gpu_node_map (Dict[int, List[int]]): Map of PyTorch CPU node IDs to GPU node IDs.
+        chakra_nodes (Dict[int, Any]): Map of Chakra node IDs to nodes.
+        phase_end_node_ids (List[int]): List of node IDs for phase dependencies.
+        input_storage_id_nid_map (Dict[int, List[int]]): Map of input storage IDs to node IDs.
+        output_storage_id_nid_map (Dict[int, List[int]]): Map of output storage IDs to node IDs.
+        input_tensor_id_nid_map (Dict[int, List[int]]): Map of input tensor IDs to node IDs.
+        output_tensor_id_nid_map (Dict[int, List[int]]): Map of output tensor IDs to node IDs.
+    """
+
     def __init__(
         self,
         input_filename: str,
         output_filename: str,
         num_dims: int,
         logger: logging.Logger
     ) -> None:
         """
-        Initializes the PyTorch to Chakra converter.
-
-        The converter also identifies data dependencies between PyTorch nodes using
-        storage IDs and tensor IDs. These dependencies are crucial for accurate
-        reconstruction of the execution graph in the Chakra format.
+        Initializes the PyTorch to Chakra converter. It sets up necessary
+        attributes and prepares the environment for the conversion process.
 
         Args:
             input_filename (str): Name of the input file containing PyTorch execution trace.
             output_filename (str): Name of the output file for the converted Chakra trace.
             num_dims (int): Number of dimensions involved in the conversion process.
             logger (logging.Logger): Logger for logging information during the conversion.
-
-        Attributes:
-            pytorch_schema (Optional[str]): Schema information of the PyTorch trace.
-            pytorch_pid (Optional[int]): Process ID associated with the PyTorch trace.
-            pytorch_time (Optional[str]): Time information of the PyTorch trace.
-            pytorch_start_ts (Optional[int]): Start timestamp of the PyTorch trace.
-            pytorch_finish_ts (Optional[int]): Finish timestamp of the PyTorch trace.
-            pytorch_nodes (Dict[int, Any]): Map of PyTorch node IDs to nodes.
-            pytorch_root_nids (List[int]): List of root node IDs in the PyTorch trace.
-            pytorch_cpu_node_id_gpu_node_map (Dict[int, List[int]]): Map of PyTorch CPU node IDs to corresponding GPU node IDs.
-            chakra_nodes (Dict[int, Any]): Map of Chakra node IDs to nodes.
-            phase_end_node_ids (List[int]): List of node IDs to enforce phase dependencies.
-            input_storage_id_nid_map (Dict[int, int]): Map of input storage IDs to node IDs.
-            output_storage_id_nid_map (Dict[int, int]): Map of output storage IDs to node IDs.
-            input_tensor_id_nid_map (Dict[int, int]): Map of input tensor IDs to node IDs.
-            output_tensor_id_nid_map (Dict[int, int]): Map of output tensor IDs to node IDs.
         """
         self.input_filename = input_filename
         self.output_filename = output_filename
@@ -108,34 +119,40 @@ def __init__(
         self.pytorch_root_nids = []
 
         # Initialize node mapping dictionaries
-        # TODO: Maybe we don't need this
         self.pytorch_cpu_node_id_gpu_node_map = {}
         self.chakra_nodes = {}
 
         # Initialize lists for phase dependencies and data dependency maps
         self.phase_end_node_ids = []
 
-        # Data dependency maps:
-        # These maps are used to establish relationships between nodes based on
-        # tensor data flow. They track which nodes are producing or consuming tensors.
-
         # Map of input storage IDs to node IDs:
-        # Tracks which nodes are consuming tensors based on their storage ID.
+        # This dictionary tracks which nodes are consuming tensors based on their
+        # storage ID, establishing a link between tensor storage and node consumption.
         self.input_storage_id_nid_map = {}
 
         # Map of output storage IDs to node IDs:
-        # Tracks which nodes are producing tensors based on their storage ID.
+        # Similar to input_storage_id_nid_map, but this tracks the production of
+        # tensors by nodes, associating tensor storage IDs with the nodes that
+        # produce them.
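+        # Illustrative sketch with hypothetical IDs: if node 7 produces the
+        # tensor with storage ID 42, this map holds
+        #     output_storage_id_nid_map[42] == [7]
+        # and any later node consuming storage ID 42 gains node 7 as a data
+        # dependency in update_data_dependencies().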
         self.output_storage_id_nid_map = {}
 
         # Map of input tensor IDs to node IDs:
-        # Similar to storage IDs, but using tensor IDs for tensors without a valid storage ID.
+        # This dictionary is used when storage IDs are not applicable. It tracks
+        # which nodes are consuming tensors by using tensor IDs, creating a link
+        # between tensor IDs and the nodes that consume them.
         self.input_tensor_id_nid_map = {}
 
         # Map of output tensor IDs to node IDs:
-        # Used for tracking tensor outputs when storage IDs are not applicable.
-        self.output_tensor_id_nid_map = {}
+        # Similar to input_tensor_id_nid_map, but for tracking the output of tensors
+        # from nodes. It associates tensor IDs with the nodes that output them,
+        # used when storage IDs are not available.
+        self.output_tensor_id_nid_map = {}
 
     def convert(self) -> None:
+        """
+        Converts PyTorch execution traces into the Chakra format. Orchestrates
+        the conversion process including trace loading, trace opening, phase
+        end node construction, node splitting, and node conversion.
+        """
         self.load_pytorch_execution_traces()
 
         self.open_chakra_execution_trace()
@@ -144,34 +161,34 @@ def convert(self) -> None:
         self.split_cpu_nodes_with_gpu_child()
 
-        self.logger.info("Convert PyTorch nodes to Chakra nodes")
         for pytorch_nid, pytorch_node in self.pytorch_nodes.items():
             if pytorch_node.is_cpu_op():
                 self.update_input_tensor_map(pytorch_node.id, pytorch_node.inputs)
                 self.update_output_tensor_map(pytorch_node.id, pytorch_node.outputs)
+
                 if pytorch_node.child_gpu:
                     pytorch_gpu_node = pytorch_node.child_gpu
                     self.update_input_tensor_map(pytorch_gpu_node.id, pytorch_gpu_node.inputs)
-                    # For now we ignore GPU->CPU dependencies since it creates unwanted dependencies.
-                    # self.update_output_tensor_map(pytorch_gpu_node["id"], pytorch_gpu_node["outputs"])
+                    # Ignoring GPU->CPU dependencies for now since it creates unwanted dependencies.
 
                 chakra_node = self.convert_to_chakra_node(pytorch_node)
                 self.chakra_nodes[chakra_node.id] = chakra_node
+
                 if pytorch_node.child_gpu:
                     pytorch_gpu_node = pytorch_node.child_gpu
                     chakra_gpu_node = self.convert_to_chakra_node(pytorch_gpu_node)
+
                     if chakra_node.type == COMM_COLL_NODE:
                         pytorch_nccl_node = self.get_nccl_node(pytorch_node)
-                        chakra_gpu_node.attr.append(
-                            ChakraAttr(name="comm_type",
-                                       int64_val=pytorch_nccl_node.collective_comm_type))
-                        chakra_gpu_node.attr.append(
-                            ChakraAttr(name="comm_size",
-                                       int64_val=pytorch_nccl_node.comm_size))
-                        attr = ChakraAttr(name="involved_dim")
-                        for _ in range(self.num_dims):
-                            attr.bool_list.values.append(True)
-                        chakra_gpu_node.attr.append(attr)
+                        chakra_gpu_node.attr.extend([
+                            ChakraAttr(name="comm_type",
+                                       int64_val=pytorch_nccl_node.collective_comm_type),
+                            ChakraAttr(name="comm_size",
+                                       int64_val=pytorch_nccl_node.comm_size),
+                            ChakraAttr(name="involved_dim",
+                                       bool_list={"values": [True] * self.num_dims})
+                        ])
+
                     chakra_gpu_node.data_deps.append(chakra_node.id)
                     self.chakra_nodes[chakra_gpu_node.id] = chakra_gpu_node
 
@@ -194,27 +211,32 @@ def load_pytorch_execution_traces(self) -> None:
         """
         Loads PyTorch execution traces from a file.
 
-        This function reads the PyTorch execution trace data from a file, parses
-        it, and instantiates PyTorchNode objects. It also establishes parent-child
-        relationships between nodes and identifies root nodes.
+        Reads and parses the PyTorch execution trace data from a file, creating
+        PyTorchNode objects and establishing node relationships.
 
         Raises:
             Exception: If there is an IOError in opening the file.
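+
+        The loader expects the top-level JSON shape sketched below (field
+        values are illustrative; real traces carry many more fields per node):
+
+            {
+                "schema": "...",
+                "pid": 123,
+                "time": "...",
+                "start_ts": 1000,
+                "finish_ts": 2000,
+                "nodes": [...]
+            }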
""" + self.logger.info("Loading PyTorch execution traces from file.") try: with open(self.input_filename, "r") as pytorch_et: pytorch_et_data = json.load(pytorch_et) self._parse_and_instantiate_nodes(pytorch_et_data) except IOError as e: + self.logger.error(f"Error opening file {self.input_filename}: {e}") raise Exception(f"Could not open file {self.input_filename}") def _parse_and_instantiate_nodes(self, pytorch_et_data: Dict) -> None: """ - Parses and instantiates PyTorch nodes. + Parses and instantiates PyTorch nodes from execution trace data. Args: pytorch_et_data (Dict): The execution trace data. + + Extracts node information, sorts nodes by timestamp, and establishes + parent-child relationships among them. """ + self.logger.info("Extracting and processing node data from execution trace.") self.pytorch_schema = pytorch_et_data["schema"] self.pytorch_pid = pytorch_et_data["pid"] self.pytorch_time = pytorch_et_data["time"] @@ -230,15 +252,28 @@ def _parse_and_instantiate_nodes(self, pytorch_et_data: Dict) -> None: self._establish_parent_child_relationships(pytorch_node_objects) + self.pytorch_nodes = pytorch_node_objects + def _establish_parent_child_relationships( self, pytorch_node_objects: Dict[int, PyTorchNode] ) -> None: """ - Establishes parent-child relationships among PyTorch nodes. + Establishes parent-child relationships among PyTorch nodes and counts the node types. Args: - pytorch_node_objects (Dict[int, PyTorchNode]): Dictionary of node objects. - """ + pytorch_node_objects (Dict[int, PyTorchNode]): Dictionary of PyTorch node objects. + """ + # Initialize counters for different types of nodes + node_type_counts = { + "total_op": 0, + "cpu_op": 0, + "gpu_op": 0, + "record_param_comms_op": 0, + "nccl_op": 0, + "root_op": 0 + } + + # Establish parent-child relationships for pytorch_node in pytorch_node_objects.values(): parent_id = pytorch_node.parent if parent_id in pytorch_node_objects: @@ -257,6 +292,22 @@ def _establish_parent_child_relationships( if pytorch_node.name in ["[pytorch|profiler|execution_graph|thread]", "[pytorch|profiler|execution_trace|thread]"]: self.pytorch_root_nids.append(pytorch_node.id) + node_type_counts["root_op"] += 1 + + # Collect statistics + node_type_counts["total_op"] += 1 + if pytorch_node.is_cpu_op(): + node_type_counts["cpu_op"] += 1 + if pytorch_node.is_gpu_op(): + node_type_counts["gpu_op"] += 1 + if pytorch_node.is_record_param_comms_op(): + node_type_counts["record_param_comms_op"] += 1 + if pytorch_node.is_nccl_op(): + node_type_counts["nccl_op"] += 1 + + # Log the counts of each node type + for node_type, count in node_type_counts.items(): + self.logger.info(f"{node_type}: {count}") self.pytorch_nodes = pytorch_node_objects @@ -267,19 +318,21 @@ def open_chakra_execution_trace(self) -> None: Raises: Exception: If there is an IOError in opening the file. """ + self.logger.info(f"Opening Chakra execution trace file: {self.output_filename}") try: self.chakra_et = open(self.output_filename, "wb") except IOError as e: + self.logger.error(f"Error opening file {self.output_filename}: {e}") raise Exception(f"Could not open file {self.output_filename}") def construct_phase_end_node_ids(self) -> None: """ Identifies the dependencies between phases in the execution trace. - This method uses a depth-first search (DFS) approach starting from phase - root nodes to find the largest Node ID (NID) in each phase. These NIDs - are used to track dependencies between different phases. 
+        Uses a depth-first search (DFS) approach starting from phase root nodes to find
+        the largest Node ID (NID) in each phase for dependency tracking.
         """
+        self.logger.info("Constructing phase end node IDs.")
         for node in self.pytorch_nodes.values():
             if self.is_phase_root_op(node):
                 largest_nid_within_phase = self.dfs(node)
@@ -303,8 +356,7 @@ def dfs(self, node: PyTorchNode) -> int:
         """
         Performs a depth-first search to find the largest Node ID (NID) in a subtree.
 
-        This method explores the subtree of the given node to find the largest
-        NID, considering only the CPU operation nodes.
+        Explores the subtree of the given node to find the largest NID among CPU operation nodes.
 
         Args:
             node (PyTorchNode): The node from which the search starts.
@@ -322,19 +374,18 @@ def dfs(self, node: PyTorchNode) -> int:
                 largest_nid = max(largest_nid, self.dfs(child_node))
         return largest_nid
 
     def split_cpu_nodes_with_gpu_child(self) -> None:
         """
-        Decomposes CPU nodes with GPU child nodes into multiple sub-nodes for concurrent execution.
+        Decomposes CPU nodes with GPU child nodes into multiple sub-nodes.
 
-        This function splits each CPU node into two parts if it has a GPU child:
-        one part runs until the start of the GPU node, and the other resumes after
-        the GPU node completes.
+        Splits each CPU node with a GPU child into two parts for accurate execution
+        mapping: one part runs until the GPU node starts, and the other covers the
+        remainder of the original CPU node's span.
 
         Raises:
             ValueError: If timestamps of GPU and CPU nodes are inconsistent.
         """
-        self.logger.info("Decomposing CPU nodes with GPU child nodes")
-
+        self.logger.info("Decomposing CPU nodes with GPU child nodes.")
         updated_pytorch_nodes: Dict[int, PyTorchNode] = {}
         for cpu_node in self.pytorch_nodes.values():
             if cpu_node.child_gpu is None:
                 updated_pytorch_nodes[cpu_node.id] = cpu_node
             else:
                 gpu_node = cpu_node.child_gpu
                 if gpu_node.ts >= (cpu_node.ts + cpu_node.dur):
-                    raise ValueError(
-                        f"Inconsistent timestamps for CPU node {cpu_node.id} "
-                        "and its GPU child"
-                    )
-
-                # Splitting the CPU node and updating with GPU node
-                cpu_node_first, cpu_node_second, updated_gpu_node = \
-                    self._split_cpu_node(cpu_node, gpu_node)
+                    self.logger.error(f"Inconsistent timestamps for CPU node {cpu_node.id} and GPU child")
+                    raise ValueError(f"Inconsistent timestamps for CPU node {cpu_node.id} and its GPU child")
+
+                cpu_node_first, cpu_node_second, updated_gpu_node = \
+                    self._split_cpu_node(cpu_node, gpu_node)
                 updated_pytorch_nodes[cpu_node_first.id] = cpu_node_first
                 updated_pytorch_nodes[cpu_node_second.id] = cpu_node_second
                 updated_pytorch_nodes[updated_gpu_node.id] = updated_gpu_node
-                self.logger.debug(
-                    f"CPU node {cpu_node.id} split into {cpu_node_first.id} "
-                    f"and {cpu_node_second.id}, with GPU child {updated_gpu_node.id}"
-                )
 
         self.pytorch_nodes = updated_pytorch_nodes
 
@@ -366,7 +410,7 @@ def _split_cpu_node(
         self, cpu_node: PyTorchNode, gpu_node: PyTorchNode
     ) -> Tuple[PyTorchNode, PyTorchNode, PyTorchNode]:
         """
-        Splits a CPU node based on the GPU node's timestamp and returns the updated GPU node.
+        Splits a CPU node based on the GPU node's timestamp.
 
         Args:
             cpu_node (PyTorchNode): Original CPU node to be split.
             gpu_node (PyTorchNode): GPU node dictating the split.
 
         Returns:
             Tuple[PyTorchNode, PyTorchNode, PyTorchNode]: The two split halves of
             the CPU node and the updated GPU node.
 
         Raises:
             ValueError: For inconsistencies in the timestamps of the nodes.
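+
+        Worked example with illustrative timestamps: a CPU node with ts=100 and
+        dur=60 whose GPU child has ts=130 is split into a first half with
+        ts=100, dur=30 (130 - 100) and a second half with ts=130, dur=30
+        (100 + 60 - 130).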
""" - # First half of the CPU node cpu_node_first = copy.deepcopy(cpu_node) cpu_node_first.id = self.id_assigner.assign_unique_id(cpu_node.id) cpu_node_first.ts = cpu_node.ts cpu_node_first.dur = gpu_node.ts - cpu_node.ts cpu_node_first.set_child_gpu = gpu_node if cpu_node_first.ts >= gpu_node.ts or cpu_node_first.dur <= 0: - raise ValueError( - f"Invalid timestamps for the first split of CPU node {cpu_node.id}" - ) + self.logger.error(f"Invalid timestamps for the first split of CPU node {cpu_node.id}") + raise ValueError(f"Invalid timestamps for the first split of CPU node {cpu_node.id}") - # Updating GPU node gpu_node_id = self.id_assigner.assign_unique_id(gpu_node.id) gpu_node.id = gpu_node_id - # Second half of the CPU node cpu_node_second = copy.deepcopy(cpu_node) cpu_node_second.id = self.id_assigner.assign_unique_id(cpu_node.id) cpu_node_second.ts = gpu_node.ts @@ -401,9 +441,8 @@ def _split_cpu_node( cpu_node_second.set_child_gpu(None) cpu_node_second.add_data_dep(cpu_node_first) if cpu_node_second.ts <= cpu_node_first.ts or cpu_node_second.dur <= 0: - raise ValueError( - f"Invalid timestamps for the second split of CPU node {cpu_node.id}" - ) + self.logger.error(f"Invalid timestamps for the second split of CPU node {cpu_node.id}") + raise ValueError(f"Invalid timestamps for the second split of CPU node {cpu_node.id}") return cpu_node_first, cpu_node_second, gpu_node @@ -415,6 +454,7 @@ def update_input_tensor_map(self, nid: int, inputs: List[List[int]]) -> None: nid (int): Node ID associated with the input tensors. inputs (List[List[int]]): List of input tensor data. """ + self.logger.debug(f"Updating input tensor map for node ID {nid}.") self.update_tensor_map(nid, inputs, is_input=True) def update_output_tensor_map(self, nid: int, outputs: List[List[int]]) -> None: @@ -425,6 +465,7 @@ def update_output_tensor_map(self, nid: int, outputs: List[List[int]]) -> None: nid (int): Node ID associated with the output tensors. outputs (List[List[int]]): List of output tensor data. """ + self.logger.debug(f"Updating output tensor map for node ID {nid}.") self.update_tensor_map(nid, outputs, is_input=False) def update_tensor_map(self, nid: int, tensors: List[List[int]], is_input: bool) -> None: @@ -437,6 +478,7 @@ def update_tensor_map(self, nid: int, tensors: List[List[int]], is_input: bool) is_input (bool): Flag indicating if the tensors are inputs. """ tensor_map = self.input_storage_id_nid_map if is_input else self.output_storage_id_nid_map + self.logger.debug(f"Updating {'input' if is_input else 'output'} tensor map for node ID {nid}.") tensor_objects = [list_to_pytorch_tensor(tensor_list) for tensor_list in tensors] for tensor in tensor_objects: @@ -452,61 +494,55 @@ def _update_tensor_map(self, tensor_map: Dict[int, List[int]], tensor: PyTorchTe tensor (PyTorchTensor): The tensor object to use for the update. nid (int): The node ID to associate with the tensor. """ - if tensor.has_valid_storage_id(): - key = tensor.storage_id - else: - key = tensor.tensor_id + key = tensor.storage_id if tensor.has_valid_storage_id() else tensor.tensor_id tensor_map.setdefault(key, []).append(nid) def convert_to_chakra_node(self, pytorch_node: PyTorchNode) -> ChakraNode: """ - Converts this PyTorchNode to a ChakraNode. + Converts a PyTorchNode to a ChakraNode. + + Args: + pytorch_node (PyTorchNode): The PyTorch node to convert. Returns: - ChakraNode: The converted ChakraNode. + ChakraNode: The converted Chakra node. 
""" + self.logger.debug(f"Converting PyTorch node ID {pytorch_node.id} to Chakra node.") + chakra_node = ChakraNode() chakra_node.id = pytorch_node.id chakra_node.name = pytorch_node.name chakra_node.type = self.get_chakra_node_type_from_pytorch_node(pytorch_node) chakra_node.ctrl_deps.append(pytorch_node.parent) - if pytorch_node.has_dur(): - chakra_node.duration_micros = pytorch_node.dur - else: - chakra_node.duration_micros = 0 + chakra_node.duration_micros = pytorch_node.dur if pytorch_node.has_dur() else 0 chakra_node.inputs.values = str(pytorch_node.inputs) chakra_node.inputs.shapes = str(pytorch_node.input_shapes) chakra_node.inputs.types = str(pytorch_node.input_types) chakra_node.outputs.values = str(pytorch_node.outputs) chakra_node.outputs.shapes = str(pytorch_node.output_shapes) chakra_node.outputs.types = str(pytorch_node.output_types) - chakra_node.attr.append( - ChakraAttr(name="rf_id", - int64_val=pytorch_node.rf_id)) - chakra_node.attr.append( - ChakraAttr(name="fw_parent", - int64_val=pytorch_node.fw_parent)) - chakra_node.attr.append( - ChakraAttr(name="seq_id", - int64_val=pytorch_node.seq_id)) - chakra_node.attr.append( - ChakraAttr(name="scope", - int64_val=pytorch_node.scope)) - chakra_node.attr.append( - ChakraAttr(name="tid", - int64_val=pytorch_node.tid)) - chakra_node.attr.append( - ChakraAttr(name="fw_tid", - int64_val=pytorch_node.fw_tid)) - chakra_node.attr.append( - ChakraAttr(name="op_schema", - string_val=pytorch_node.op_schema)) - chakra_node.attr.append( - ChakraAttr(name="is_cpu_node", - bool_val=pytorch_node.is_cpu_op())) + chakra_node.attr.extend([ + ChakraAttr(name="rf_id", int64_val=pytorch_node.rf_id), + ChakraAttr(name="fw_parent", int64_val=pytorch_node.fw_parent), + ChakraAttr(name="seq_id", int64_val=pytorch_node.seq_id), + ChakraAttr(name="scope", int64_val=pytorch_node.scope), + ChakraAttr(name="tid", int64_val=pytorch_node.tid), + ChakraAttr(name="fw_tid", int64_val=pytorch_node.fw_tid), + ChakraAttr(name="op_schema", string_val=pytorch_node.op_schema), + ChakraAttr(name="is_cpu_node", bool_val=pytorch_node.is_cpu_op()) + ]) return chakra_node def get_chakra_node_type_from_pytorch_node(self, pytorch_node: PyTorchNode) -> int: + """ + Determines the Chakra node type from a PyTorch node. + + Args: + pytorch_node (PyTorchNode): The PyTorch node to determine the type of. + + Returns: + int: The corresponding Chakra node type. + """ if pytorch_node.is_gpu_op() and ("ncclKernel" in pytorch_node.name): return COMM_COLL_NODE elif pytorch_node.is_gpu_op(): @@ -521,27 +557,32 @@ def get_nccl_node(self, node: PyTorchNode) -> PyTorchNode: """ Returns a PyTorch NCCL node for a given Chakra CPU node. - For communication nodes, finding a corresponding NCCL node is critical - to identify the communication type and communication size. + Critical for identifying communication type and size in communication nodes. + There are two primary cases to consider: when the given node is a parent + of a record_param_comms node or a NCCL node. - There are two cases: - (1) Given node is a parent of a record_param_comms node - * In this case, the corresponding NCCL node should be a child of - the record_param_comms_pt node. - (2) Given node is a parent of a NCCL node + Args: + node (PyTorchNode): The parent node for which the NCCL node is needed. + + Returns: + PyTorchNode: The corresponding NCCL node. + + Raises: + ValueError: If no corresponding NCCL node is found. 
""" - nccl_node = None + self.logger.debug(f"Retrieving NCCL node for PyTorch node ID {node.id}.") if node.record_param_comms_node: record_param_comms_node = node.record_param_comms_node if record_param_comms_node.nccl_node: - nccl_node = record_param_comms_node.nccl_node + return record_param_comms_node.nccl_node else: - raise ValueError("") + self.logger.error("No NCCL node found in the record_param_comms node.") + raise ValueError("No corresponding NCCL node found in the record_param_comms node.") elif node.nccl_node: - nccl_node = node.nccl_node + return node.nccl_node else: - raise ValueError("") - return nccl_node + self.logger.error("No NCCL node associated with the given PyTorch node.") + raise ValueError("No corresponding NCCL node found.") # TODO def get_prev_inter_phase_dep_nid(self, node: ChakraNode) -> int: @@ -549,7 +590,15 @@ def get_prev_inter_phase_dep_nid(self, node: ChakraNode) -> int: Returns the NID of the latest node of the previous phase. Finds the closest but smaller value from phase_end_node_ids compared to node.id. + This is used to determine the dependencies between different phases in the trace. + + Args: + node (ChakraNode): The node to find the previous phase dependency for. + + Returns: + int: NID of the latest node of the previous phase, or -1 if none. """ + self.logger.debug(f"Finding previous inter-phase dependency for node ID {node.id}.") index = bisect.bisect_left(self.phase_end_node_ids, node.id) if index == 0: @@ -561,38 +610,41 @@ def get_prev_inter_phase_dep_nid(self, node: ChakraNode) -> int: def identify_data_dependency(self) -> None: """ - Identifies data dependency between nodes using tensors. + Identifies data dependencies between nodes using tensor input/output relationships. - Dependencies between nodes can be identified by their tensor input/output relationships. - A tensor can be identified by either a storage ID or a tensor ID. Use the storage ID if it's valid; - otherwise, use the tensor ID. + Determines the relationships based on whether the tensors use storage IDs or tensor IDs. """ - self.logger.info("Identify data dependency") + self.logger.info("Identifying data dependencies among nodes.") self.identify_data_dependency_with_storage_id() self.identify_data_dependency_with_tensor_id() def identify_data_dependency_with_storage_id(self) -> None: """ - Identifies data dependency between nodes with storage IDs. + Identifies data dependency between nodes based on storage IDs. + + Uses the mapping of input and output tensors to their storage IDs to establish dependencies. """ - self.logger.info("Identify data dependency with storage IDs") + self.logger.info("Identifying data dependencies using storage IDs.") self.update_data_dependencies(self.input_storage_id_nid_map, self.output_storage_id_nid_map) def identify_data_dependency_with_tensor_id(self) -> None: """ - Identifies data dependency between nodes with tensor IDs. + Identifies data dependency between nodes based on tensor IDs. + + Establishes dependencies using tensor IDs for tensors without valid storage IDs. """ - self.logger.info("Identify data dependency with tensor IDs") + self.logger.info("Identifying data dependencies using tensor IDs.") self.update_data_dependencies(self.input_tensor_id_nid_map, self.output_tensor_id_nid_map) def update_data_dependencies(self, input_map: Dict[int, List[int]], output_map: Dict[int, List[int]]) -> None: """ - Updates data dependencies for nodes based on input and output dictionaries. 
+        Updates data dependencies for nodes based on input and output tensor maps.
 
         Args:
-            input_map (Dict[int, List[int]]): Dictionary mapping input IDs to node IDs.
-            output_map (Dict[int, List[int]]): Dictionary mapping output IDs to node IDs.
+            input_map (Dict[int, List[int]]): Map of input tensor IDs to node IDs.
+            output_map (Dict[int, List[int]]): Map of output tensor IDs to node IDs.
         """
+        self.logger.debug("Updating data dependencies for nodes.")
         for input_id, child_nids in input_map.items():
             if input_id in output_map:
                 parent_nids = output_map[input_id]
@@ -605,18 +657,22 @@ def update_data_dependencies(self, input_map: Dict[int, List[int]], output_map:
 
     def write_chakra_et(self) -> None:
         """
         Writes the Chakra execution trace by encoding global metadata and nodes.
-        """
-        self.logger.info("Starting to write Chakra trace.")
 
+        Encodes and writes both the metadata and individual nodes to create a complete execution trace.
+        """
+        self.logger.info("Writing Chakra execution trace.")
         self._write_global_metadata()
 
         self._encode_and_write_nodes()
 
-        self.logger.info("Successfully wrote all Chakra nodes to the output file.")
+        self.logger.info("Chakra execution trace writing completed.")
 
     def _write_global_metadata(self) -> None:
-        """ Encodes and writes global metadata for the Chakra execution trace. """
-        self.logger.info("Encoding global metadata.")
+        """
+        Encodes and writes global metadata for the Chakra execution trace.
+
+        This process includes encoding metadata like schema, process ID, timestamps,
+        and other relevant information for the Chakra execution trace.
+        """
+        self.logger.info("Encoding global metadata for Chakra execution trace.")
         global_metadata = GlobalMetadata(
             attr=[
                 ChakraAttr(name="schema", string_val=self.pytorch_schema),
@@ -629,14 +685,18 @@ def _write_global_metadata(self) -> None:
         encode_message(self.chakra_et, global_metadata)
 
     def _encode_and_write_nodes(self) -> None:
-        """ Encodes and writes nodes for the Chakra execution trace. """
-        self.logger.info("Encoding nodes.")
+        """
+        Encodes and writes nodes for the Chakra execution trace.
+
+        Each node from the PyTorch execution trace is encoded and written into the
+        Chakra format. This includes node IDs, names, types, dependencies, and other attributes.
+        """
+        self.logger.info("Encoding and writing nodes for Chakra execution trace.")
         seen_nids = set()
         for nid in sorted(self.chakra_nodes.keys()):
             if nid in seen_nids:
-                error_msg = f"Duplicate NID {nid} detected in Chakra nodes."
-                self.logger.error(error_msg)
-                raise ValueError(error_msg)
+                self.logger.error(f"Duplicate NID {nid} detected in Chakra nodes.")
+                raise ValueError(f"Duplicate NID {nid} detected in Chakra nodes.")
             seen_nids.add(nid)
             chakra_node = self.chakra_nodes[nid]
             encode_message(self.chakra_et, chakra_node)
@@ -644,6 +704,9 @@ def _encode_and_write_nodes(self) -> None:
     def close_chakra_execution_trace(self) -> None:
         """
         Closes the Chakra execution trace file if it is open.
+
+        Ensures proper closure of the trace file to preserve data integrity.
         """
+        self.logger.info("Closing Chakra execution trace file.")
         if self.chakra_et and not self.chakra_et.closed:
             self.chakra_et.close()
diff --git a/et_converter/pytorch_tensor.py b/et_converter/pytorch_tensor.py
index 26abd8d7..fdaa30fd 100644
--- a/et_converter/pytorch_tensor.py
+++ b/et_converter/pytorch_tensor.py
@@ -1,5 +1,6 @@
 from typing import Any
 
+
 class PyTorchTensor:
     """
     Represents a tensor with its associated properties.
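
Usage sketch for the refactored converter (hypothetical file names; the
constructor signature and the convert() entry point are taken from the patch
above):

    import logging

    from chakra.et_converter.pytorch2chakra_converter import PyTorch2ChakraConverter

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("et_converter")

    converter = PyTorch2ChakraConverter(
        input_filename="pytorch_et.json",  # hypothetical PyTorch trace file
        output_filename="chakra.et",       # hypothetical Chakra output file
        num_dims=2,                        # network dims for involved_dim
        logger=logger,
    )
    converter.convert()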