diff --git a/src/trace_link/trace_link.py b/src/trace_link/trace_link.py
index bb51a870..e48103dd 100644
--- a/src/trace_link/trace_link.py
+++ b/src/trace_link/trace_link.py
@@ -29,23 +29,23 @@ class KinetoOperator:
 
     Attributes:
         op_dict (Dict[str, Any]): Dictionary containing the operator data.
-        category (Optional[str]): Category of the operator.
-        name (Optional[str]): Name of the operator.
+        category (str): Category of the operator.
+        name (str): Name of the operator.
         phase (Optional[str]): Phase of the operator.
         inclusive_dur (int): Inclusive duration of the operator in microseconds.
         exclusive_dur (int): Exclusive duration of the operator in microseconds.
         timestamp (int): Timestamp of the operator in microseconds.
-        external_id (Optional[str]): External ID associated with the operator.
-        ev_idx (Optional[str]): Event index associated with the operator.
-        tid (Optional[int]): Thread ID associated with the operator.
+        external_id (str): External ID associated with the operator.
+        ev_idx (str): Event index associated with the operator.
+        tid (int): Thread ID associated with the operator.
         pytorch_op (Optional[PyTorchOperator]): Associated PyTorch operator.
         parent_pytorch_op_id (Optional[int]): ID of the parent PyTorch operator.
         inter_thread_dep (Optional[int]): ID of the latest CPU node from other
            threads before the gap.
         stream (Optional[int]): Stream ID associated with the operator.
         rf_id (Optional[int]): Record function ID.
-        correlation (Optional[int]): Correlation ID used to link CUDA runtime
-            operations with their GPU counterparts.
+        correlation (int): Correlation ID used to link CUDA runtime operations
+            with their GPU counterparts.
     """
 
     def __init__(self, kineto_op: Dict[str, Any]) -> None:
@@ -57,25 +57,25 @@ def __init__(self, kineto_op: Dict[str, Any]) -> None:
                 operator data.
         """
         self.op_dict = kineto_op
-        self.category = kineto_op.get("cat")
-        self.name = kineto_op.get("name")
+        self.category = kineto_op.get("cat", "")
+        self.name = kineto_op.get("name", "")
         self.phase = kineto_op.get("ph")
         self.inclusive_dur = kineto_op.get("dur", 0)
         self.exclusive_dur = kineto_op.get("dur", 0)
         self.timestamp = kineto_op.get("ts", 0)
-        self.external_id = None
-        self.ev_idx = None
-        self.tid = kineto_op.get("tid")
+        self.external_id = ""
+        self.ev_idx = ""
+        self.tid = kineto_op.get("tid", 0)
         self.pytorch_op: Optional[PyTorchOperator] = None
         self.parent_pytorch_op_id = None
         self.inter_thread_dep: Optional[int] = None
         self.stream: Optional[int] = None
         self.rf_id: Optional[int] = None
-        self.correlation: Optional[int] = None
+        self.correlation: int = 0
 
         if "args" in kineto_op:
-            self.external_id = kineto_op["args"].get("External id")
-            self.ev_idx = kineto_op["args"].get("Ev Idx")
+            self.external_id = kineto_op["args"].get("External id", "")
+            self.ev_idx = kineto_op["args"].get("Ev Idx", "")
             self.stream = kineto_op["args"].get("stream")
             if "Record function id" in kineto_op["args"]:
                 self.rf_id = int(kineto_op["args"]["Record function id"])
@@ -761,7 +761,7 @@ def group_gpu_ops_by_cpu_launchers(self) -> Dict[str, List[KinetoOperator]]:
                 self.logger.warning(warning_msg)
                 continue
 
-            if parent_cpu_op.ev_idx is None:
+            if parent_cpu_op.ev_idx == "":
                 error_msg = (
                     f"Missing 'ev_idx' for CPU operator {parent_cpu_op.name}. "
                     f"Cannot link to GPU op {gpu_op.name} to {parent_cpu_op.name}."
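The hunks above replace `Optional` attribute types with concrete typed defaults: `""` for string fields and `0` for integer fields, so downstream code tests a sentinel value (`parent_cpu_op.ev_idx == ""`) instead of `None`. A minimal sketch of the pattern follows; the `Op` class and its fields are hypothetical stand-ins, not the real `KinetoOperator`:

```python
from typing import Any, Dict


class Op:
    """Hypothetical stand-in for KinetoOperator, showing the sentinel-default pattern."""

    def __init__(self, raw: Dict[str, Any]) -> None:
        # Fields are always str/int; "" and 0 mark "absent", so the attributes
        # are never None and type checkers see non-Optional types.
        self.name: str = raw.get("name", "")
        self.tid: int = raw.get("tid", 0)
        self.ev_idx: str = raw.get("args", {}).get("Ev Idx", "")


op = Op({"name": "aten::mm"})
if op.ev_idx == "":  # sentinel check, mirroring `parent_cpu_op.ev_idx == ""`
    print(f"cannot link {op.name}: missing Ev Idx")
```

The trade-off is that a sentinel must never collide with a legitimate value; this holds only on the assumption that Kineto traces never emit an empty string or a thread ID of 0 for these fields.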
@@ -912,7 +912,7 @@ def link_ops(
         self.pytorch_op_id_to_exclusive_dur_map[pytorch_op.id] = kineto_op.exclusive_dur
         self.pytorch_op_id_to_timestamp_map[pytorch_op.id] = kineto_op.timestamp
         if kineto_op.inter_thread_dep:
-            inter_thread_dep_kineto_op = self.kineto_rf_id_to_kineto_op_map[kineto_op.inter_thread_dep]
+            inter_thread_dep_kineto_op = self.kineto_rf_id_to_kineto_op_map[str(kineto_op.inter_thread_dep)]
             if inter_thread_dep_kineto_op.pytorch_op:
                 self.pytorch_op_id_to_inter_thread_dep_map[pytorch_op.id] = inter_thread_dep_kineto_op.pytorch_op.id
         if kineto_op.ev_idx in cpu_ev_idx_to_gpu_ops_map:
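This last hunk fixes a key-type mismatch: `kineto_rf_id_to_kineto_op_map` is keyed by strings, while `inter_thread_dep` holds an `int`, so the unconverted lookup would raise `KeyError` even when the ID was present. A small self-contained illustration of why the `str()` cast matters (the map contents here are made up):

```python
# Dict lookups match on hash and equality, so an int key never finds a str key.
rf_id_to_op = {"42": "op-42"}  # hypothetical map built with string keys

inter_thread_dep = 42  # recorded as an int elsewhere in the trace

try:
    rf_id_to_op[inter_thread_dep]  # int 42 != str "42": raises KeyError
except KeyError:
    print("int key missed")

print(rf_id_to_op[str(inter_thread_dep)])  # str(42) -> "42": prints "op-42"
```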