From 036a0dc2018fe6f7eed2d509fdd55e3094f096e6 Mon Sep 17 00:00:00 2001 From: alongd Date: Mon, 24 Jun 2019 13:23:56 -0400 Subject: [PATCH] Added a function to save an input file (and not run ARC on spot) --- arc/common.py | 26 ++++++++++++++++++++++++++ arc/main.py | 18 +++++++++++++++++- arc/scheduler.py | 23 ++--------------------- 3 files changed, 45 insertions(+), 22 deletions(-) diff --git a/arc/common.py b/arc/common.py index 1580751417..a4e41b8cc7 100644 --- a/arc/common.py +++ b/arc/common.py @@ -252,3 +252,29 @@ def log_footer(execution_time, level=logging.INFO): logger.log(level, '') logger.log(level, 'Total execution time: {0}'.format(execution_time)) logger.log(level, 'ARC execution terminated on {0}'.format(time.asctime())) + + +def save_dict_file(path, restart_dict): + """ + Save an input / restart YAML file + """ + yaml.add_representer(str, string_representer) + yaml.add_representer(unicode, unicode_representer) + logger.debug('Creating a restart file...') + content = yaml.dump(data=restart_dict, encoding='utf-8', allow_unicode=True) + with open(path, 'w') as f: + f.write(content) + + +def string_representer(dumper, data): + """Add a custom string representer to use block literals for multiline strings""" + if len(data.splitlines()) > 1: + return dumper.represent_scalar(tag='tag:yaml.org,2002:str', value=data, style='|') + return dumper.represent_scalar(tag='tag:yaml.org,2002:str', value=data) + + +def unicode_representer(dumper, data): + """Add a custom unicode representer to use block literals for multiline strings""" + if len(data.splitlines()) > 1: + return yaml.ScalarNode(tag='tag:yaml.org,2002:str', value=data, style='|') + return yaml.ScalarNode(tag='tag:yaml.org,2002:str', value=data) diff --git a/arc/main.py b/arc/main.py index aa8712fe6b..b6166004b7 100644 --- a/arc/main.py +++ b/arc/main.py @@ -21,7 +21,8 @@ import arc.rmgdb as rmgdb from arc.settings import arc_path, default_levels_of_theory, servers, valid_chars, default_job_types from arc.scheduler import Scheduler -from arc.common import VERSION, read_file, time_lapse, check_ess_settings, initialize_log, log_footer, get_logger +from arc.common import VERSION, read_file, time_lapse, check_ess_settings, initialize_log, log_footer, get_logger,\ + save_dict_file from arc.arc_exceptions import InputError, SettingsError, SpeciesError from arc.species.species import ARCSpecies from arc.reaction import ARCReaction @@ -601,6 +602,21 @@ def from_dict(self, input_dict, project=None, project_directory=None): else: self.arc_rxn_list = list() + def write_input_file(self, path=None): + """ + Save the current attributes as an ARC input file. + + Args: + path (str, unicode, optional): The full path for the generated input file. + """ + if path is None: + path = os.path.join(self.project_directory, 'input.yml') + base_path = os.path.dirname(path) + if not os.path.isdir(base_path): + os.makedirs(base_path) + logger.info('\n\nWriting input file to {0}'.format(path)) + save_dict_file(path=path, restart_dict=self.restart_dict) + def execute(self): """Execute ARC""" logger.info('\n') diff --git a/arc/scheduler.py b/arc/scheduler.py index 76728c0a8f..64d4b114be 100644 --- a/arc/scheduler.py +++ b/arc/scheduler.py @@ -13,7 +13,6 @@ import numpy as np import math import shutil -import yaml import logging from IPython.display import display @@ -23,7 +22,7 @@ from rmgpy.reaction import Reaction from rmgpy.exceptions import InputError as RMGInputError -from arc.common import get_logger +from arc.common import get_logger, save_dict_file import arc.rmgdb as rmgdb from arc import plotter from arc import parser @@ -1857,8 +1856,6 @@ def save_restart_dict(self): Update the restart_dict and save the restart.yml file """ if self.save_restart and self.restart_dict is not None: - yaml.add_representer(str, string_representer) - yaml.add_representer(unicode, unicode_representer) logger.debug('Creating a restart file...') self.restart_dict['output'] = self.output self.restart_dict['species'] = [spc.as_dict() for spc in self.species_dict.values()] @@ -1870,10 +1867,8 @@ def save_restart_dict(self): for job_name in self.running_jobs[spc.label] if 'conformer' not in job_name]\ + [self.job_dict[spc.label]['conformers'][int(job_name.split('mer')[1])].as_dict() for job_name in self.running_jobs[spc.label] if 'conformer' in job_name] - content = yaml.dump(data=self.restart_dict, encoding='utf-8', allow_unicode=True) - with open(self.restart_path, 'w') as f: - f.write(content) logger.debug('Dumping restart dictionary:\n{0}'.format(self.restart_dict)) + save_dict_file(path=self.restart_path, restart_dict=self.restart_dict) def make_reaction_labels_info_file(self): """A helper function for creating the `reactions labels.info` file""" @@ -1964,20 +1959,6 @@ def min_list(lst): return min([entry for entry in lst if entry is not None]) -def string_representer(dumper, data): - """Add a custom string representer to use block literals for multiline strings""" - if len(data.splitlines()) > 1: - return dumper.represent_scalar(tag='tag:yaml.org,2002:str', value=data, style='|') - return dumper.represent_scalar(tag='tag:yaml.org,2002:str', value=data) - - -def unicode_representer(dumper, data): - """Add a custom unicode representer to use block literals for multiline strings""" - if len(data.splitlines()) > 1: - return yaml.ScalarNode(tag='tag:yaml.org,2002:str', value=data, style='|') - return yaml.ScalarNode(tag='tag:yaml.org,2002:str', value=data) - - def sum_time_delta(timedelta_list): """A helper function for summing datetime.timedelta objects""" result = datetime.timedelta(0)