From db2a7736599723c0dc6bc187389aa3e150880d80 Mon Sep 17 00:00:00 2001
From: Marc Sun
Date: Wed, 14 Aug 2024 13:26:04 +0200
Subject: [PATCH 1/4] destroy process group

---
 src/accelerate/accelerator.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 902b9f4dbc7..09a1f042052 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -2678,11 +2678,10 @@ def log(self, values: dict, step: int | None = None, log_kwargs: dict | None = {}):
         for tracker in self.trackers:
             tracker.log(values, step=step, **log_kwargs.get(tracker.name, {}))
 
-    @on_main_process
     def end_training(self):
         """
-        Runs any special end training behaviors, such as stopping trackers on the main process only. Should always be
-        called at the end of your script if using experiment tracking.
+        Runs any special end training behaviors, such as stopping trackers on the main process only or destoying all process.
+        Should always be called at the end of your script if using experiment tracking.
 
         Example:
 
@@ -2695,8 +2694,12 @@ def end_training(self):
         >>> accelerator.end_training()
         ```
         """
-        for tracker in self.trackers:
-            tracker.finish()
+        with self.on_main_process():
+            for tracker in self.trackers:
+                tracker.finish()
+        if torch.distributed.is_initialized():
+            # needed when using torch.distributed.init_process_group
+            torch.distributed.destroy_process_group()
 
     def save(self, obj, f, safe_serialization=False):
         """

From 472fe6febc7b49a3f07173d3f5531e527e4b10a3 Mon Sep 17 00:00:00 2001
From: Marc Sun
Date: Wed, 14 Aug 2024 13:29:51 +0200
Subject: [PATCH 2/4] rephrase

---
 src/accelerate/accelerator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 09a1f042052..e79c77d0690 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -2680,7 +2680,7 @@ def log(self, values: dict, step: int | None = None, log_kwargs: dict | None = {}):
 
     def end_training(self):
         """
-        Runs any special end training behaviors, such as stopping trackers on the main process only or destoying all process.
+        Runs any special end training behaviors, such as stopping trackers on the main process only or destoying process group.
         Should always be called at the end of your script if using experiment tracking.
 
         Example:

From a0a4d8b66a66496f8009e75469a9f062a359ab80 Mon Sep 17 00:00:00 2001
From: Marc Sun
Date: Wed, 14 Aug 2024 16:12:56 +0200
Subject: [PATCH 3/4] style

---
 src/accelerate/accelerator.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index e79c77d0690..10ea8a2a88f 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -2680,8 +2680,8 @@ def log(self, values: dict, step: int | None = None, log_kwargs: dict | None = {}):
 
     def end_training(self):
         """
-        Runs any special end training behaviors, such as stopping trackers on the main process only or destoying process group.
-        Should always be called at the end of your script if using experiment tracking.
+        Runs any special end training behaviors, such as stopping trackers on the main process only or destoying
+        process group. Should always be called at the end of your script if using experiment tracking.
 
         Example:

From 2856b9767828ba9a86007b4c8f06e9f474daddf5 Mon Sep 17 00:00:00 2001
From: Marc Sun
Date: Wed, 14 Aug 2024 16:38:34 +0200
Subject: [PATCH 4/4] fix on_main_process

---
 src/accelerate/accelerator.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 10ea8a2a88f..a913fcbb3c2 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -2694,9 +2694,9 @@ def end_training(self):
         >>> accelerator.end_training()
         ```
         """
-        with self.on_main_process():
-            for tracker in self.trackers:
-                tracker.finish()
+        for tracker in self.trackers:
+            tracker.finish()
+
         if torch.distributed.is_initialized():
             # needed when using torch.distributed.init_process_group
             torch.distributed.destroy_process_group()
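Taken together, the four patches mean a training script no longer has to tear down the `torch.distributed` process group itself: `Accelerator.end_training()` now does it whenever a process group was initialized. Below is a minimal usage sketch of the resulting behavior, assuming a distributed launch such as `accelerate launch script.py`; the model, optimizer, and elided loop are illustrative placeholders, not part of the patch.

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Placeholder model and optimizer, only to make the sketch self-contained.
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

# ... training loop elided ...

# With PATCH 4/4 applied, this finishes any experiment trackers and,
# when torch.distributed.is_initialized() is True, also calls
# torch.distributed.destroy_process_group(), so the script exits
# without leaving a dangling process group behind.
accelerator.end_training()
```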