From 3cd145c8d4b91e80d0811e0c5fec4309228cb80c Mon Sep 17 00:00:00 2001
From: Siming Dai <908660116@qq.com>
Date: Wed, 16 Oct 2024 15:50:45 +0800
Subject: [PATCH] [Unified Checkpoint] update async save logic (#9274) (#9275)

* update async save signal

* fix async save hang
---
 .../trainer/plugins/unified_checkpoint.py | 53 +++++++++++---
 paddlenlp/trainer/trainer.py              | 71 +++++++++++++------
 paddlenlp/trainer/trainer_utils.py        |  8 ++-
 paddlenlp/trainer/training_args.py        |  5 ++
 4 files changed, 103 insertions(+), 34 deletions(-)

diff --git a/paddlenlp/trainer/plugins/unified_checkpoint.py b/paddlenlp/trainer/plugins/unified_checkpoint.py
index 2bfb6b0d5a15..f826ebfb620f 100644
--- a/paddlenlp/trainer/plugins/unified_checkpoint.py
+++ b/paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -136,7 +136,6 @@ def __init__(self, args):
         self._process_master_weight = None
         self._process_optimizer_weight = None
         self._lock = None
-        self._shared_save_path = None
         self._shared_save_model_flag = None
         self._shared_save_master_weight_flag = None
         self._shared_save_optimizer_flag = None
@@ -144,13 +143,18 @@ def __init__(self, args):
         if "async_save" in self.args.unified_checkpoint_config:
             self._lock = multiprocessing.Lock()
             self._shared_save_model_path = multiprocessing.Array("c", 100000)
+            self._shared_save_model_signal_path = multiprocessing.Array("c", 100000)
             self._shared_save_master_weight_path = multiprocessing.Array("c", 100000)
+            self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000)
             self._shared_save_optimizer_path = multiprocessing.Array("c", 100000)
+            self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000)
             self._shared_save_model_flag = multiprocessing.Array("i", 1)
             self._shared_save_master_weight_flag = multiprocessing.Array("i", 1)
             self._shared_save_optimizer_flag = multiprocessing.Array("i", 1)
 
-    def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_type="model_weight"):
+    def _file_save_async_or_sync(
+        self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight"
+    ):
         if is_sync:
             for k in list(state_dict.keys()):
                 if isinstance(state_dict[k], paddle.Tensor):
@@ -165,6 +169,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                 meta_dict = self._meta_dict_model
                 shared_save_flag = self._shared_save_model_flag
                 shared_save_path = self._shared_save_model_path
+                shared_save_signal_path = self._shared_save_model_signal_path
                 if self._process_model_weight is None:
                     self._process_model_weight = multiprocessing.Process(
                         target=self._save_file_async_in_process,
@@ -173,12 +178,14 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                             self._shm_model_weight.name,
                             self._shared_save_model_flag,
                             self._shared_save_model_path,
+                            self._shared_save_model_signal_path,
                             self._lock,
                             state_dict_type,
                             self.global_rank,
                         ),
                     )
                     self._process_model_weight.start()
+                    process = self._process_model_weight
             elif state_dict_type == "master_weight":
                 if self._shm_master_weight is None:
                     self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict)
@@ -187,6 +194,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                 meta_dict = self._meta_dict_master_weight
                 shared_save_flag = self._shared_save_master_weight_flag
                 shared_save_path = self._shared_save_master_weight_path
+                shared_save_signal_path = self._shared_save_master_weight_signal_path
                 if self._process_master_weight is None:
                     self._process_master_weight = multiprocessing.Process(
                         target=self._save_file_async_in_process,
@@ -195,6 +203,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                             self._shm_master_weight.name,
                             self._shared_save_master_weight_flag,
                             self._shared_save_master_weight_path,
+                            self._shared_save_master_weight_signal_path,
                             self._lock,
                             "model_weight"
                             if "skip_save_model_weight" in self.args.unified_checkpoint_config
@@ -203,6 +212,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                         ),
                     )
                     self._process_master_weight.start()
+                    process = self._process_master_weight
             elif state_dict_type == "optimizer_weight":
                 if self._shm_optimizer_weight is None:
                     self._meta_dict_optim, buffer_size = create_meta_dict(state_dict)
@@ -211,6 +221,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                 meta_dict = self._meta_dict_optim
                 shared_save_flag = self._shared_save_optimizer_flag
                 shared_save_path = self._shared_save_optimizer_path
+                shared_save_signal_path = self._shared_save_optimizer_signal_path
                 if self._process_optimizer_weight is None:
                     self._process_optimizer_weight = multiprocessing.Process(
                         target=self._save_file_async_in_process,
@@ -219,21 +230,26 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                             self._shm_optimizer_weight.name,
                             self._shared_save_optimizer_flag,
                             self._shared_save_optimizer_path,
+                            self._shared_save_optimizer_signal_path,
                             self._lock,
                             state_dict_type,
                             self.global_rank,
                         ),
                     )
                     self._process_optimizer_weight.start()
+                    process = self._process_optimizer_weight
 
             while True:  # wait until no process is saving.
                 flag_value = shared_save_flag[0]
                 if flag_value == 0:
                     break
+                if not process.is_alive():
+                    raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.")
                 time.sleep(0.5)
                 logger.info(f"Wait for the previous save process to finish saving {state_dict_type}")
             # only save model weight or save master weight, we enter this loop.
             self._reset_and_update(shared_save_path, path)
+            self._reset_and_update(shared_save_signal_path, signal_path)
             _traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf)
             with self._lock:
                 shared_save_flag[0] = 1
@@ -244,6 +260,7 @@ def _save_file_async_in_process(
         shm_name,
         shared_save_flag,
         shared_save_path,
+        shared_save_signal_path,
         lock,
         state_dict_type,
         global_rank,
@@ -257,11 +274,12 @@ def _save_file_async_in_process(
                 continue
             if flag_value == 1:  # need to save
                 path = shared_save_path[:].decode("utf-8").rstrip("\x00")
+                signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00")
                 logger.info(f"Start to async save {path}")
                 state_dict = _read_state_dict_from_shm(meta_dict, shm)  # numpy array
                 safe_save_file(state_dict, path, {"format": "np"})
                 del state_dict
-                saved_signal_path = os.path.join(os.path.dirname(path), f".{state_dict_type}.done.{global_rank}")
+                saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
                 paddle.save(global_rank, saved_signal_path)
                 with lock:
                     shared_save_flag[0] = 0
@@ -276,7 +294,7 @@ def _reset_and_update(self, shared_array, new_value):
         encoded_value = new_value.encode("utf-8")
         shared_array[: len(encoded_value)] = encoded_value
 
-    def save_unified_checkpoint(self, model, optimizer, output_dir):
+    def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None):
         """save unified checkpoint
 
         Args:
@@ -313,6 +331,8 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):
 
         save_directory = output_dir
         os.makedirs(save_directory, exist_ok=True)
+        if signal_dir is not None:
+            os.makedirs(signal_dir, exist_ok=True)  # only for async save
 
         # save model weights
         if not skip_save_model_weight:
@@ -325,6 +345,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):
             self._file_save_async_or_sync(
                 state_dict,
                 path=os.path.join(save_directory, shard_file),
+                signal_path=signal_dir,
                 is_sync=is_sync_save,
                 state_dict_type="model_weight",
             )
@@ -393,7 +414,7 @@ def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str)
         if self.args.dataset_rank == 0:
             load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True)
 
-    def save_non_merge_optimizer(self, model, optimizer, output_dir):
+    def save_non_merge_optimizer(self, model, optimizer, output_dir, signal_dir):
         paddle.device.cuda.empty_cache()
         optim_state_dict = nested_copy(optimizer.state_dict())
         master_weights = None
@@ -432,12 +453,14 @@ def save_non_merge_optimizer(self, model, optimizer, output_dir):
         self._file_save_async_or_sync(
             optim_state_dict,
             path=os.path.join(output_dir, optimizer_name),
+            signal_path=signal_dir,
            is_sync=is_sync_save,
            state_dict_type="optimizer_weight",
        )
        self._file_save_async_or_sync(
            master_weights,
            path=os.path.join(output_dir, master_weights_name),
+            signal_path=signal_dir,
            is_sync=is_sync_save,
            state_dict_type="master_weight",
        )
@@ -484,22 +507,23 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint):
 
        return returned_optim_state_dict
 
-    def save_unified_optimizer(self, model, optimizer, output_dir):
+    def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
        """save unified optimizer
 
        Args:
            model (PretrainedModel): model used to get key mapping.
            optimizer (Optimizer): optimizer to save
            output_dir (str): Save directory.
+            signal_dir (str): Asynchronous saving signal directory.
""" if "ignore_merge_optimizer" in self.args.unified_checkpoint_config: - self.save_non_merge_optimizer(model, optimizer, output_dir) + self.save_non_merge_optimizer(model, optimizer, output_dir, signal_dir) return if paddle.distributed.get_world_size() <= 1: - self.save_single_card_optimizer(model, optimizer, output_dir) + self.save_single_card_optimizer(model, optimizer, output_dir) # no need to save signal return # Split into naive optimizer params and master weights. @@ -515,6 +539,8 @@ def save_unified_optimizer(self, model, optimizer, output_dir): save_directory = output_dir os.makedirs(save_directory, exist_ok=True) + if signal_dir is not None: + os.makedirs(signal_dir, exist_ok=True) is_sync_save = True if "async_save" in self.args.unified_checkpoint_config: @@ -522,6 +548,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir): self._file_save_async_or_sync( optim_state_dict, path=os.path.join(save_directory, shard_optim_file), + signal_path=signal_dir, is_sync=is_sync_save, state_dict_type="optimizer_weight", ) @@ -529,6 +556,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir): self._file_save_async_or_sync( master_weight_state_dict, path=os.path.join(save_directory, shard_master_weight_file), + signal_path=signal_dir, is_sync=is_sync_save, state_dict_type="master_weight", ) @@ -716,14 +744,20 @@ def unlink_shared_memory(self): if self._shared_save_model_flag is not None: while self._shared_save_model_flag[0] > 0: # async process is saving + if not self._process_model_weight.is_alive(): + raise RuntimeError("The process that saves model_weight has been killed unexpectedly.") time.sleep(0.5) self._shared_save_model_flag[0] = -1 if self._shared_save_master_weight_flag is not None: while self._shared_save_master_weight_flag[0] > 0: + if not self._process_master_weight.is_alive(): + raise RuntimeError("The process that saves master_weight has been killed unexpectedly.") time.sleep(0.5) self._shared_save_master_weight_flag[0] = -1 if self._shared_save_optimizer_flag is not None: while self._shared_save_optimizer_flag[0] > 0: + if not self._process_optimizer_weight.is_alive(): + raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.") time.sleep(0.5) self._shared_save_optimizer_flag[0] = -1 @@ -740,7 +774,8 @@ def unlink_shared_memory(self): self._shm_optimizer_weight.unlink() self._shm_optimizer_weight = None - dist.barrier() + if paddle.distributed.get_world_size() > 1: + dist.barrier() def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False): diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 9a3c894ada8e..5e7a276b10bb 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -2203,7 +2203,12 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle return loss.detach() - def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Optional[bool] = False): + def save_model( + self, + output_dir: Optional[str] = None, + merge_tensor_parallel: Optional[bool] = False, + signal_dir: Optional[str] = None, + ): """ Will save the model, so you can reload it using `from_pretrained()`. 
@@ -2213,17 +2218,20 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
        if output_dir is None:
            output_dir = self.args.output_dir
 
+        if signal_dir is None:
+            signal_dir = self.args.output_signal_dir
+
        if ShardingOption.FULL_SHARD in self.args.sharding:
            self.model_wrapped.get_all_parameters(convert2cpu=True)
 
        if self.args.should_save_model_state:
-            self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
+            self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel, signal_dir=signal_dir)
        else:
            if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
-                os.makedirs(output_dir, exist_ok=True)
+                os.makedirs(signal_dir, exist_ok=True)
                if self.is_in_train:
                    global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
-                    paddle.save(global_rank, os.path.join(output_dir, f".model_weight.done.{global_rank}"))
+                    paddle.save(global_rank, os.path.join(signal_dir, f".model_weight.done.{global_rank}"))
 
        if strtobool(os.getenv("FLAG_LLM_PDC", "False")):
            # save model_done file to ensure model is complete
@@ -2239,9 +2247,9 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
                and "async_save" in self.args.unified_checkpoint_config
                and not self.is_in_train
            ):
-                os.makedirs(output_dir, exist_ok=True)
+                os.makedirs(signal_dir, exist_ok=True)
                global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
-                paddle.save(self.state.global_step, os.path.join(output_dir, f".model_weight.done.{global_rank}"))
+                paddle.save(self.state.global_step, os.path.join(signal_dir, f".model_weight.done.{global_rank}"))
 
    def _filter_moe_no_sync_optimizer_params(self):
        """
@@ -2252,7 +2260,7 @@ def _filter_moe_no_sync_optimizer_params(self):
        filter_optimzier_state_dict = OrderedDict()
        param_names_in_master_weights = list(optimzier_state_dict["master_weights"].keys()) if self.args.bf16 else []
        filter_optimzier_state_dict["master_weights"] = OrderedDict()
-        for k, v in state_dict.items():
+        for _, v in state_dict.items():
            if getattr(v, "no_sync", False):
                if v.name in param_names_in_master_weights:
                    filter_optimzier_state_dict["master_weights"][v.name] = optimzier_state_dict["master_weights"][
@@ -2270,15 +2278,17 @@ def _save_checkpoint(self, model, metrics=None):
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
 
        run_dir = self.args.output_dir
+        run_signal_dir = self.args.output_signal_dir
 
        output_dir = os.path.join(run_dir, checkpoint_folder)
+        signal_dir = os.path.join(run_signal_dir, checkpoint_folder)
 
        if isinstance(self.model, LoRAModel) and (self.model.quantized or self.args.pipeline_parallel_degree > 1):
-            self.save_model(output_dir)
+            self.save_model(output_dir, False, signal_dir)
        elif isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM):
-            self.save_model(output_dir, True)
+            self.save_model(output_dir, True, signal_dir)
        else:
-            self.save_model(output_dir)
+            self.save_model(output_dir, False, signal_dir)
 
        # only save model state dict, ignore optimizer and scheduler
        if not self.args.ignore_save_lr_and_optim:
@@ -2293,6 +2303,7 @@ def _save_checkpoint(self, model, metrics=None):
                            self.model,
                            self.optimizer,
                            output_dir,
+                            signal_dir,
                        )
                    else:
                        if self.dp_group.rank > 0:  # this should only work for MoE saving
@@ -2308,10 +2319,10 @@ def _save_checkpoint(self, model, metrics=None):
            else:
                if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                    global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
-                    os.makedirs(output_dir, exist_ok=True)
-                    paddle.save(global_rank, os.path.join(output_dir, f".optimizer_weight.done.{global_rank}"))
+                    os.makedirs(signal_dir, exist_ok=True)
+                    paddle.save(global_rank, os.path.join(signal_dir, f".optimizer_weight.done.{global_rank}"))
                    if "skip_save_model_weight" not in self.args.unified_checkpoint_config:
-                        paddle.save(global_rank, os.path.join(output_dir, f".master_weight.done.{global_rank}"))
+                        paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}"))
            if self.args.should_save or self.args.use_expert_parallel:
                if not self.args.use_hybrid_parallel:
                    logger.info("Saving optimizer files.")
@@ -2320,6 +2331,7 @@ def _save_checkpoint(self, model, metrics=None):
                        self.model,
                        self.optimizer,
                        output_dir,
+                        signal_dir,
                    )
                else:
                    if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
@@ -2337,11 +2349,11 @@ def _save_checkpoint(self, model, metrics=None):
        else:
            if self.args.unified_checkpoint and not self.args.use_hybrid_parallel:
                if "async_save" in self.args.unified_checkpoint_config:
-                    os.makedirs(output_dir, exist_ok=True)
                    global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
-                    paddle.save(global_rank, os.path.join(output_dir, f".optimizer_weight.done.{global_rank}"))
+                    os.makedirs(signal_dir, exist_ok=True)
+                    paddle.save(global_rank, os.path.join(signal_dir, f".optimizer_weight.done.{global_rank}"))
                    if "skip_save_model_weight" not in self.args.unified_checkpoint_config:
-                        paddle.save(global_rank, os.path.join(output_dir, f".master_weight.done.{global_rank}"))
+                        paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}"))
        self.runtime_timer.stop()
 
        # Determine the new best metric / best model checkpoint
@@ -2390,7 +2402,7 @@ def _save_checkpoint(self, model, metrics=None):
        # For hybrid parallel training, the checkpoint files maybe on different node.
        need_to_rotate_checkpoints = False
        if self.args.use_hybrid_parallel:
-            if self.dp_group.rank <= 0:
+            if self.dp_group.rank <= 0 or self.args.use_expert_parallel:
                need_to_rotate_checkpoints = True
        else:
            need_to_rotate_checkpoints = self.args.should_save_model_state
@@ -2399,6 +2411,7 @@ def _save_checkpoint(self, model, metrics=None):
        need_to_rotate_checkpoints = need_to_rotate_checkpoints and self.args.local_rank == 0
        if need_to_rotate_checkpoints:
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
+            self._rotate_checkpoints(use_mtime=True, output_dir=run_signal_dir)
 
        if strtobool(os.getenv("FLAG_LLM_PDC", "False")) and not ("async_save" in self.args.unified_checkpoint_config):
            # save checkpoint_done file to ensure checkpoint is complete
@@ -2473,10 +2486,23 @@ def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
            # ignore_errors for shared disks between train nodes.
            shutil.rmtree(checkpoint, ignore_errors=True)
 
-    def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_parallel=False):
+    def _save(
+        self,
+        output_dir: Optional[str] = None,
+        state_dict=None,
+        merge_tensor_parallel=False,
+        signal_dir: Optional[str] = None,
+    ):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")
+
+        # signal_dir is used for asynchronous saving situations.
+        if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
+            signal_dir = signal_dir if signal_dir is not None else self.args.output_signal_dir
+            os.makedirs(signal_dir, exist_ok=True)
+            logger.info(f"Saving model checkpoint finish signal to {signal_dir}")
+
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
@@ -2486,16 +2512,15 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
            and self.args.unified_checkpoint
            and "async_save" in self.args.unified_checkpoint_config
        ):
-            os.makedirs(self.args.logging_dir, exist_ok=True)
            world_size = paddle.distributed.get_world_size()
            save_info = {
                "world_size": world_size,
                "ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim,
                "skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config,
            }
-            if os.path.exists(os.path.join(self.args.logging_dir, "async_save_info.json")):  # afs cannot overwrite
-                os.remove(os.path.join(self.args.logging_dir, "async_save_info.json"))
-            with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f:
+            if os.path.exists(os.path.join(signal_dir, "async_save_info.json")):  # afs cannot overwrite
+                os.remove(os.path.join(signal_dir, "async_save_info.json"))
+            with open(os.path.join(signal_dir, "async_save_info.json"), "w") as f:
                json.dump(save_info, f)
 
        if self.args.should_save:
@@ -2510,7 +2535,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
            if not self.is_in_train:
                self.args.unified_checkpoint_config = []
 
-            self.unified_checkpoint_handler.save_unified_checkpoint(self.model, self.optimizer, output_dir)
+            self.unified_checkpoint_handler.save_unified_checkpoint(self.model, self.optimizer, output_dir, signal_dir)
 
            # recover unified_checkpoint_config for not trine stage
            if not self.is_in_train:
diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py
index 03f6974bbc7b..a36f76607042 100644
--- a/paddlenlp/trainer/trainer_utils.py
+++ b/paddlenlp/trainer/trainer_utils.py
@@ -248,7 +248,7 @@ def _check_checkpoint_files(folder_path, world_size, ignore_save_lr_and_optim, s
    return a
 
 
-def get_last_checkpoint(folder, uc_async_save=False):
+def get_last_checkpoint(folder, signal_folder=None, uc_async_save=False):
    content = os.listdir(folder)
    checkpoints = [
        path
@@ -258,6 +258,9 @@ def get_last_checkpoint(folder, uc_async_save=False):
    if len(checkpoints) == 0:
        return
 
+    if uc_async_save:
+        assert signal_folder is not None
+
    if strtobool(os.getenv("FLAG_LLM_PDC", "False")):
        for i in sorted(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0]), reverse=True):
            current_path = os.path.join(folder, i)
@@ -267,11 +270,12 @@ def get_last_checkpoint(folder, uc_async_save=False):
                    return current_path
            else:
                saving_info = paddle.load(distributed_file(os.path.join(current_path, ".saving_info")))
+                current_signal_path = os.path.join(signal_folder, i)
                pre_world_size = saving_info.get("world_size", 1)
                ignore_save_lr_and_optim = saving_info.get("ignore_save_lr_and_optim", False)
                skip_save_model_weight = saving_info.get("skip_save_model_weight", False)
                if _check_checkpoint_files(
-                    current_path, pre_world_size, ignore_save_lr_and_optim, skip_save_model_weight
+                    current_signal_path, pre_world_size, ignore_save_lr_and_optim, skip_save_model_weight
                ):
                    return current_path
    return
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 08a04604ec6d..29ce7591e010 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -420,6 +420,7 @@ class TrainingArguments:
        },
    )
    logging_dir: Optional[str] = field(default=None, metadata={"help": "VisualDL log dir."})
+    output_signal_dir: Optional[str] = field(default=None, metadata={"help": "Asynchronous saving signal dir."})
    logging_strategy: IntervalStrategy = field(
        default="steps",
        metadata={"help": "The logging strategy to use."},
    )
@@ -827,6 +828,10 @@ def __post_init__(self):
            self.logging_dir = os.path.join(self.output_dir, default_logdir())
        if self.logging_dir is not None:
            self.logging_dir = os.path.expanduser(self.logging_dir)
+        if self.output_signal_dir is None and self.output_dir is not None:
+            self.output_signal_dir = self.output_dir
+        if self.output_signal_dir is not None:
+            self.output_signal_dir = os.path.expanduser(self.output_signal_dir)
 
        if self.disable_tqdm is None:
            self.disable_tqdm = False  # logger.getEffectiveLevel() > logging.WARN
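
For reference (not part of the patch): a minimal sketch of the done-marker handshake that this change moves from output_dir into the dedicated signal directory. Each rank's async writer drops a hidden ".{state_dict_type}.done.{global_rank}" file into signal_dir once its shard is on disk, and checkpoint discovery later treats a checkpoint as resumable only when every expected marker exists, which is roughly what _check_checkpoint_files verifies against the signal folder. The helper names and paths below are illustrative assumptions, not PaddleNLP APIs.

    import os

    def write_done_marker(signal_dir: str, state_dict_type: str, global_rank: int) -> None:
        # Mimics the marker the async worker leaves after flushing a shard
        # (the patch itself writes it with paddle.save(global_rank, ...)).
        os.makedirs(signal_dir, exist_ok=True)
        marker = os.path.join(signal_dir, f".{state_dict_type}.done.{global_rank}")
        with open(marker, "w") as f:
            f.write(str(global_rank))

    def is_checkpoint_complete(signal_dir: str, state_dict_type: str, world_size: int) -> bool:
        # A checkpoint counts as complete only when every rank has left its marker.
        return all(
            os.path.exists(os.path.join(signal_dir, f".{state_dict_type}.done.{rank}"))
            for rank in range(world_size)
        )

    # Hypothetical usage for a 2-rank job: rank 0 has finished, rank 1 has not.
    write_done_marker("/tmp/ckpt_signal/checkpoint-100", "optimizer_weight", 0)
    print(is_checkpoint_complete("/tmp/ckpt_signal/checkpoint-100", "optimizer_weight", 2))  # False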
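
Also for reference, a simplified sketch of the shared-flag wait loop and the liveness check behind the "fix async save hang" part of this patch. The real handler hands the state dict to the worker through shared memory; here a bare flag and a no-op writer stand in for that, so only the control flow is meaningful.

    import multiprocessing
    import time

    def _writer(flag):
        # Background worker: waits for work (flag == 1), "saves", then marks itself idle (flag == 0).
        while True:
            if flag[0] == -1:  # parent requested shutdown
                break
            if flag[0] == 1:
                # ... persist the shard and drop the ".done" marker here ...
                flag[0] = 0
            time.sleep(0.5)

    def wait_until_idle(flag, process, state_dict_type="model_weight"):
        # Before the patch, this loop could spin forever if the worker had been killed;
        # the added is_alive() check turns a silent hang into a clear error.
        while flag[0] != 0:
            if not process.is_alive():
                raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.")
            time.sleep(0.5)

    if __name__ == "__main__":
        flag = multiprocessing.Array("i", 1)  # 0: idle, 1: save requested, -1: exit
        proc = multiprocessing.Process(target=_writer, args=(flag,))
        proc.start()
        flag[0] = 1                 # hand over a (dummy) shard
        wait_until_idle(flag, proc)  # would raise instead of hanging if proc died
        flag[0] = -1                # ask the worker to exit
        proc.join()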