diff --git a/run_time_test.py b/run_time_test.py index 3b561f4..3d09dab 100644 --- a/run_time_test.py +++ b/run_time_test.py @@ -1,6 +1,8 @@ import time from contextlib import nullcontext from typing import Callable +from collections import defaultdict +import json import torch import torch.utils._pytree as pytree @@ -10,6 +12,7 @@ from torch.utils._mode_utils import no_dispatch from torch.utils._python_dispatch import TorchDispatchMode from torch.utils.flop_counter import flop_registry +from torch.utils.module_tracker import ModuleTracker from fsdp_test import GPT, GPTConfig @@ -78,6 +81,10 @@ def __init__(self): } self.no_fallback_kernel = set() self.total_time: float = 0.0 + self.mod_tracker = ModuleTracker() + self.fwd_runtimes: Dict[str, int] = defaultdict(lambda: 0) + self.bwd_runtimes: Dict[str, int] = defaultdict(lambda: 0) + # Adapted from: https://github.com/pytorch/pytorch/blob/main/torch/_subclasses/fake_tensor.py#L1838 # NB: returns fake tensors def _maybe_run_and_benchmark_fallback_kernel( @@ -176,11 +183,17 @@ def _dispatch_benchmark_estimate(self, func, args, kwargs): ) ) self.total_time += mean_op_time + for par in self.mod_tracker.parents: + if self.mod_tracker.is_bw: + self.bwd_runtimes[par] += mean_op_time + else: + self.fwd_runtimes[par] += mean_op_time return res except NotImplementedError: self.no_fallback_kernel.add(func._overloadpacket) res = func(*args, **kwargs or {}) return res + # Adapted from: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/scheduler.py#L563 def _dispatch_inductor_estimate(self, func, args, kwargs): def get_num_bytes(t: torch.Tensor) -> int: @@ -251,6 +264,11 @@ def get_transfer_time(flat_args_kwargs, flat_outs): # compute time. We divide by 1e6 to get the time in ms op_time = max(transfer_time, compute_time) / 1e6 self.total_time += op_time + for par in self.mod_tracker.parents: + if self.mod_tracker.is_bw: + self.bwd_runtimes[par] += op_time + else: + self.fwd_runtimes[par] += op_time return out @@ -280,6 +298,9 @@ def __enter__(self): self.fake_mode = fake_mode self.total_time = 0.0 super().__enter__() + self.fwd_runtimes.clear() + self.bwd_runtimes.clear() + self.mod_tracker.__enter__() return self def __exit__(self, *args): @@ -289,6 +310,13 @@ def __exit__(self, *args): ) if len(self.no_fallback_kernel) > 0: print("no_fallback_kernel: ", list(self.no_fallback_kernel)) + print( + "===== FORWARD RUN TIMES =====\n", + json.dumps(dict(self.fwd_runtimes), indent=4), + "===== BACKWARD RUN TIMES =====\n", + json.dumps(dict(self.bwd_runtimes), indent=4) + ) + self.mod_tracker.__exit__() return super().__exit__(*args)