Support multi-gpu child ops and remove CPU/GPU op splitting #28

Merged 3 commits on Apr 2, 2024
165 changes: 4 additions & 161 deletions et_converter/pytorch2chakra_converter.py
@@ -1,6 +1,5 @@
#!/usr/bin/env python3

import copy
import json
import logging
from typing import Dict, List, Optional, Tuple, Set
@@ -161,16 +160,13 @@ def convert(self) -> None:

self.open_chakra_execution_trace()

self.split_cpu_nodes_with_gpu_child()

for pytorch_nid, pytorch_node in self.pytorch_nodes.items():
if (pytorch_node.get_op_type() == PyTorchNodeType.CPU_OP)\
or (pytorch_node.get_op_type() == PyTorchNodeType.LABEL):
chakra_node = self.convert_to_chakra_node(pytorch_node)
self.chakra_nodes[chakra_node.id] = chakra_node

if pytorch_node.child_gpu:
pytorch_gpu_node = pytorch_node.child_gpu
for pytorch_gpu_node in pytorch_node.gpu_children:
chakra_gpu_node = self.convert_to_chakra_node(pytorch_gpu_node)

if chakra_node.type == COMM_COLL_NODE:
@@ -273,7 +269,7 @@ def _establish_parent_child_relationships(
parent_node.add_child(pytorch_node)

if pytorch_node.is_gpu_op():
parent_node.set_child_gpu(pytorch_node)
parent_node.add_gpu_child(pytorch_node)

if pytorch_node.is_record_param_comms_op():
parent_node.record_param_comms_node = pytorch_node
@@ -318,160 +314,6 @@ def open_chakra_execution_trace(self) -> None:
self.logger.error(err_msg)
raise Exception(err_msg)

def split_cpu_nodes_with_gpu_child(self) -> None:
"""
Decomposes CPU nodes with GPU child nodes to model execution overlap
accurately. This method addresses scenarios where a CPU node has a GPU
child node, with an overlap in their execution ending at the same time.
The method splits the CPU node into:
1. Non-Overlapping Part: Segment before the GPU node starts.
2. Overlapping Part: Segment overlapping with the GPU node.

Timeline Stages:
Stage 1 - Original Scenario:
|------------ CPU Node ------------|
|--- GPU Node ---|

Stage 2 - After Split:
|-- Non-Overlap --|--- Overlap ----|
|--- GPU Node ---|

Raises:
ValueError: If timestamps of GPU and CPU nodes are inconsistent.
"""
self.logger.info("Decomposing CPU nodes with GPU child nodes.")
updated_pytorch_nodes: Dict[int, PyTorchNode] = {}
for cpu_node in self.pytorch_nodes.values():
if cpu_node.child_gpu is None:
new_cpu_node_id = self.id_assigner.assign_unique_id(cpu_node.id)
cpu_node.id = new_cpu_node_id
for child_node in cpu_node.children:
child_node.parent = cpu_node.id
updated_pytorch_nodes[new_cpu_node_id] = cpu_node
else:
if cpu_node.exclusive_dur > 1:
gpu_node = cpu_node.child_gpu
cpu_node_first, cpu_node_second, updated_gpu_node =\
self._split_cpu_node(cpu_node, gpu_node, updated_pytorch_nodes)
updated_pytorch_nodes[cpu_node_first.id] = copy.deepcopy(cpu_node_first)
updated_pytorch_nodes[cpu_node_second.id] = copy.deepcopy(cpu_node_second)
updated_pytorch_nodes[updated_gpu_node.id] = copy.deepcopy(updated_gpu_node)
else:
new_cpu_node_id = self.id_assigner.assign_unique_id(cpu_node.id)
cpu_node.id = new_cpu_node_id
for child_node in cpu_node.children:
child_node.parent = cpu_node.id
updated_pytorch_nodes[new_cpu_node_id] = cpu_node

gpu_node = cpu_node.child_gpu
gpu_node.parent = new_cpu_node_id
new_gpu_node_id = self.id_assigner.assign_unique_id(gpu_node.id)
updated_pytorch_nodes[new_gpu_node_id] = gpu_node

self.pytorch_nodes = updated_pytorch_nodes

def _split_cpu_node(
self, cpu_node: PyTorchNode, gpu_node: PyTorchNode,
updated_pytorch_nodes: Dict[int, PyTorchNode]
) -> Tuple[PyTorchNode, PyTorchNode, PyTorchNode]:
"""
Splits a CPU node based on the GPU node's timestamp.

Args:
cpu_node (PyTorchNode): Original CPU node to be split.
gpu_node (PyTorchNode): GPU node dictating the split.
updated_pytorch_nodes (Dict[int, PyTorchNode]): Updated PyTorch nodes.

Returns:
Tuple[PyTorchNode, PyTorchNode, PyTorchNode]: Two split nodes and
the updated GPU node.

Raises:
ValueError: For inconsistencies in the timestamps of the nodes.
"""
original_cpu_info = f"Original CPU Node ID {cpu_node.id} ({cpu_node.name}), " \
f"Inclusive Duration: {cpu_node.inclusive_dur}, " \
f"Exclusive Duration: {cpu_node.exclusive_dur}."
self.logger.debug(original_cpu_info)
self.logger.debug(f"GPU Node ID {gpu_node.id} ({gpu_node.name}), "
f"Inclusive Duration: {gpu_node.inclusive_dur}, "
f"Exclusive Duration: {gpu_node.exclusive_dur}.")

cpu_node_first = copy.deepcopy(cpu_node)
cpu_node_first.id = self.id_assigner.assign_unique_id(cpu_node.id)
cpu_node_first.ts = cpu_node.ts
cpu_node_first.exclusive_dur = int(cpu_node.exclusive_dur / 2)
cpu_node_first.set_child_gpu(gpu_node)
if cpu_node_first.ts >= gpu_node.ts or cpu_node_first.inclusive_dur <= 0:
err_msg = (f"Invalid timestamps for the first split CPU node derived from {original_cpu_info}\n"
f"\tFirst Split CPU Node Timestamp: {cpu_node_first.ts}, \n"
f"\tGPU Node Timestamp: {gpu_node.ts}, \n"
f"\tFirst Split CPU Node Inclusive Duration: {cpu_node_first.inclusive_dur}, \n"
f"\tFirst Split CPU Node Exclusive Duration: {cpu_node_first.exclusive_dur}.")
self.logger.error(err_msg)
raise ValueError(err_msg)

if cpu_node.parent in self.pytorch_nodes:
self._update_parent_node_children(self.pytorch_nodes, cpu_node, cpu_node_first)
elif cpu_node.parent in updated_pytorch_nodes:
self._update_parent_node_children(updated_pytorch_nodes, cpu_node, cpu_node_first)

self.logger.debug(f"First Split CPU Node ID {cpu_node_first.id} ({cpu_node_first.name}), "
f"Inclusive Duration: {cpu_node_first.inclusive_dur}, "
f"Exclusive Duration: {cpu_node_first.exclusive_dur}.")

gpu_node_id = self.id_assigner.assign_unique_id(gpu_node.id)
gpu_node.id = gpu_node_id
gpu_node.parent = cpu_node_first.id

cpu_node_second = copy.deepcopy(cpu_node)
cpu_node_second.id = self.id_assigner.assign_unique_id(cpu_node.id)
cpu_node_second.ts = gpu_node.ts
cpu_node_second.exclusive_dur = int(cpu_node.exclusive_dur / 2)
cpu_node_second.set_child_gpu(None)
cpu_node_second.parent = cpu_node_first.id
for child_node in cpu_node.children:
child_node.parent = cpu_node_second.id
cpu_node_second.add_child(child_node)
if cpu_node_second.ts <= cpu_node_first.ts or cpu_node_second.inclusive_dur <= 0:
err_msg = (f"Invalid timestamps for the second split CPU node derived from {original_cpu_info}\n"
f"\tFirst Split Timestamp: {cpu_node_first.ts}, \n"
f"\tSecond Split Timestamp: {cpu_node_second.ts}, \n"
f"\tSecond Split Inclusive Duration: {cpu_node_second.inclusive_dur}, "
f"\tSecond Split Exclusive Duration: {cpu_node_second.exclusive_dur}.")
self.logger.error(err_msg)
raise ValueError(err_msg)

self.logger.debug(f"Second Split CPU Node ID {cpu_node_second.id} ({cpu_node_second.name}), "
f"Inclusive Duration: {cpu_node_second.inclusive_dur}, "
f"Exclusive Duration: {cpu_node_second.exclusive_dur}.")

cpu_node_first.add_child(cpu_node_second)
cpu_node_first.add_child(gpu_node)

return cpu_node_first, cpu_node_second, gpu_node

def _update_parent_node_children(self, parent_node_dict: Dict[int, PyTorchNode],
cpu_node: PyTorchNode,
cpu_node_first: PyTorchNode) -> None:
"""
Updates the children of the parent node in the given dictionary.

This method removes the original CPU node from the parent's children list
and adds the first split node.

Args:
parent_node_dict (Dict[int, PyTorchNode]): Dictionary containing the
parent node.
cpu_node (PyTorchNode): Original CPU node being split.
cpu_node_first (PyTorchNode): First split node to add to the parent's
children.
"""
parent_node = parent_node_dict[cpu_node.parent]
parent_node.children = [child for child in parent_node.children
if child.id != cpu_node.id]
parent_node.children.extend([cpu_node_first])

def convert_to_chakra_node(self, pytorch_node: PyTorchNode) -> ChakraNode:
"""
Converts a PyTorchNode to a ChakraNode.
@@ -708,7 +550,8 @@ def remove_dangling_nodes(self) -> None:
if node_id not in parent_ids and not node.data_deps:
dangling_nodes.append(node)
del self.chakra_nodes[node_id]
del self.pytorch_nodes[node_id]
if node_id in self.pytorch_nodes:
del self.pytorch_nodes[node_id]

if dangling_nodes:
self.logger.info(f"Identified and removed {len(dangling_nodes)} dangling nodes:")
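For readers skimming this diff, the core data-model change is that a CPU op now keeps a list of GPU children instead of a single child_gpu pointer, and the converter no longer splits the CPU op around that child. Below is a minimal, self-contained sketch of the new shape; the Node class and the op names are illustrative stand-ins, not the repo's real classes.

from typing import List

# Illustrative stand-in for the fields this PR touches; the real PyTorchNode
# (et_converter/pytorch_node.py) has many more members.
class Node:
    def __init__(self, node_id: int, name: str) -> None:
        self.id = node_id
        self.name = name
        self.gpu_children: List["Node"] = []  # replaces the old single child_gpu

    def add_gpu_child(self, gpu_child: "Node") -> None:
        self.gpu_children.append(gpu_child)

# A CPU op that launches kernels on several GPUs keeps every launch, and
# conversion simply iterates over them instead of splitting the CPU op:
cpu_op = Node(1, "cpu_op_launching_kernels")
cpu_op.add_gpu_child(Node(2, "kernel_on_gpu0"))
cpu_op.add_gpu_child(Node(3, "kernel_on_gpu1"))

for gpu_child in cpu_op.gpu_children:
    print(f"convert GPU child {gpu_child.id}: {gpu_child.name}")

The diff above also guards the del in remove_dangling_nodes with a membership check, since a Chakra node id is no longer guaranteed to have a matching entry in pytorch_nodes.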
15 changes: 9 additions & 6 deletions et_converter/pytorch_node.py
@@ -31,7 +31,7 @@ def __init__(self, node_data: Dict[str, Any]) -> None:
self.node_data = node_data
self.data_deps: List['PyTorchNode'] = []
self.children: List['PyTorchNode'] = []
self.child_gpu: Optional['PyTorchNode'] = None
self.gpu_children: List['PyTorchNode'] = []
self.record_param_comms_node: Optional['PyTorchNode'] = None
self.nccl_node: Optional['PyTorchNode'] = None

@@ -419,7 +419,9 @@ def inclusive_dur(self) -> int:
Returns:
int: The inclusive duration of the node.
"""
return self.node_data["inclusive_dur"]
if "inclusive_dur" in self.node_data:
return self.node_data["inclusive_dur"]
return 0

@inclusive_dur.setter
def inclusive_dur(self, value: int) -> None:
@@ -543,14 +545,14 @@ def add_child(self, child_node: 'PyTorchNode') -> None:
"""
self.children.append(child_node)

def set_child_gpu(self, child_gpu_node: Optional['PyTorchNode']) -> None:
def add_gpu_child(self, gpu_child_node: 'PyTorchNode') -> None:
"""
Sets a child GPU node for this node.
Adds a child GPU node to this node.

Args:
child_gpu_node (Optional[PyTorchNode]): The child GPU node to be set.
gpu_child_node (PyTorchNode): The child GPU node to be added.
"""
self.child_gpu = child_gpu_node
self.gpu_children.append(gpu_child_node)

def is_record_param_comms_op(self) -> bool:
"""
@@ -620,6 +622,7 @@ def get_data_type_size(data_type: str) -> int:
"Tensor(int64)": 8,
"Tensor(long)": 8,
"Tensor(c10::Half)": 2,
"Tensor(c10::BFloat16)": 2,
"Tensor(unsigned char)": 1,
"Tensor(long int)": 8,
# TODO: Add more types
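The pytorch_node.py changes are defensive in the same spirit: inclusive_dur now falls back to 0 when the trace omits the field, and the dtype-size table gains Tensor(c10::BFloat16) at 2 bytes so byte sizes derived from element counts can be computed for bf16 tensors as well. Below is a small sketch of that size lookup; the table repeats only a few of the repo's entries, and the 0 fallback for unknown dtypes is this sketch's choice, not necessarily the real method's behavior.

DATA_TYPE_SIZE_BYTES = {
    "Tensor(int64)": 8,
    "Tensor(c10::Half)": 2,
    "Tensor(c10::BFloat16)": 2,  # 2-byte type, same width as half precision
    "Tensor(unsigned char)": 1,
}

def tensor_size_bytes(data_type: str, num_elements: int) -> int:
    # Unknown dtypes contribute 0 bytes in this sketch.
    return DATA_TYPE_SIZE_BYTES.get(data_type, 0) * num_elements

print(tensor_size_bytes("Tensor(c10::BFloat16)", 1024))  # -> 2048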