Skip to content

Commit

Permalink
Added a function to save an input file (and not run ARC on spot)
Browse files Browse the repository at this point in the history
  • Loading branch information
alongd committed Jun 24, 2019
1 parent 2658051 commit 036a0dc
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 22 deletions.
26 changes: 26 additions & 0 deletions arc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,29 @@ def log_footer(execution_time, level=logging.INFO):
logger.log(level, '')
logger.log(level, 'Total execution time: {0}'.format(execution_time))
logger.log(level, 'ARC execution terminated on {0}'.format(time.asctime()))


def save_dict_file(path, restart_dict):
    """
    Save an input / restart YAML file.

    Args:
        path (str): The full path for the YAML file to be written.
        restart_dict (dict): The content to dump into the file.
    """
    yaml.add_representer(str, string_representer)
    try:
        # `unicode` only exists in Python 2; on Python 3 every string is `str`
        # and is already covered by the representer registered above.
        yaml.add_representer(unicode, unicode_representer)
    except NameError:
        pass
    logger.debug('Creating a restart file...')
    content = yaml.dump(data=restart_dict, encoding='utf-8', allow_unicode=True)
    if not isinstance(content, str):
        # With `encoding` given, yaml.dump returns bytes on Python 3; decode so
        # the text-mode handle below accepts it. On Python 2 bytes IS str, so
        # this branch is skipped and the original behavior is preserved.
        content = content.decode('utf-8')
    with open(path, 'w') as f:
        f.write(content)


def string_representer(dumper, data):
    """A custom str representer: dump multiline strings as YAML block literals."""
    multiline = len(data.splitlines()) > 1
    return dumper.represent_scalar(tag='tag:yaml.org,2002:str', value=data,
                                   style='|' if multiline else None)


def unicode_representer(dumper, data):
    """A custom unicode representer: dump multiline strings as YAML block literals."""
    multiline = len(data.splitlines()) > 1
    return yaml.ScalarNode(tag='tag:yaml.org,2002:str', value=data,
                           style='|' if multiline else None)
18 changes: 17 additions & 1 deletion arc/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
import arc.rmgdb as rmgdb
from arc.settings import arc_path, default_levels_of_theory, servers, valid_chars, default_job_types
from arc.scheduler import Scheduler
from arc.common import VERSION, read_file, time_lapse, check_ess_settings, initialize_log, log_footer, get_logger
from arc.common import VERSION, read_file, time_lapse, check_ess_settings, initialize_log, log_footer, get_logger,\
save_dict_file
from arc.arc_exceptions import InputError, SettingsError, SpeciesError
from arc.species.species import ARCSpecies
from arc.reaction import ARCReaction
Expand Down Expand Up @@ -601,6 +602,21 @@ def from_dict(self, input_dict, project=None, project_directory=None):
else:
self.arc_rxn_list = list()

def write_input_file(self, path=None):
    """
    Save the current attributes as an ARC input file.

    Args:
        path (str, unicode, optional): The full path for the generated input file.
    """
    if path is None:
        path = os.path.join(self.project_directory, 'input.yml')
    folder = os.path.dirname(path)
    if not os.path.isdir(folder):
        os.makedirs(folder)
    logger.info('\n\nWriting input file to {0}'.format(path))
    save_dict_file(path=path, restart_dict=self.restart_dict)

def execute(self):
"""Execute ARC"""
logger.info('\n')
Expand Down
23 changes: 2 additions & 21 deletions arc/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import numpy as np
import math
import shutil
import yaml
import logging
from IPython.display import display

Expand All @@ -23,7 +22,7 @@
from rmgpy.reaction import Reaction
from rmgpy.exceptions import InputError as RMGInputError

from arc.common import get_logger
from arc.common import get_logger, save_dict_file
import arc.rmgdb as rmgdb
from arc import plotter
from arc import parser
Expand Down Expand Up @@ -1857,8 +1856,6 @@ def save_restart_dict(self):
Update the restart_dict and save the restart.yml file
"""
if self.save_restart and self.restart_dict is not None:
yaml.add_representer(str, string_representer)
yaml.add_representer(unicode, unicode_representer)
logger.debug('Creating a restart file...')
self.restart_dict['output'] = self.output
self.restart_dict['species'] = [spc.as_dict() for spc in self.species_dict.values()]
Expand All @@ -1870,10 +1867,8 @@ def save_restart_dict(self):
for job_name in self.running_jobs[spc.label] if 'conformer' not in job_name]\
+ [self.job_dict[spc.label]['conformers'][int(job_name.split('mer')[1])].as_dict()
for job_name in self.running_jobs[spc.label] if 'conformer' in job_name]
content = yaml.dump(data=self.restart_dict, encoding='utf-8', allow_unicode=True)
with open(self.restart_path, 'w') as f:
f.write(content)
logger.debug('Dumping restart dictionary:\n{0}'.format(self.restart_dict))
save_dict_file(path=self.restart_path, restart_dict=self.restart_dict)

def make_reaction_labels_info_file(self):
"""A helper function for creating the `reactions labels.info` file"""
Expand Down Expand Up @@ -1964,20 +1959,6 @@ def min_list(lst):
return min([entry for entry in lst if entry is not None])


def string_representer(dumper, data):
"""Add a custom string representer to use block literals for multiline strings"""
if len(data.splitlines()) > 1:
return dumper.represent_scalar(tag='tag:yaml.org,2002:str', value=data, style='|')
return dumper.represent_scalar(tag='tag:yaml.org,2002:str', value=data)


def unicode_representer(dumper, data):
"""Add a custom unicode representer to use block literals for multiline strings"""
if len(data.splitlines()) > 1:
return yaml.ScalarNode(tag='tag:yaml.org,2002:str', value=data, style='|')
return yaml.ScalarNode(tag='tag:yaml.org,2002:str', value=data)


def sum_time_delta(timedelta_list):
"""A helper function for summing datetime.timedelta objects"""
result = datetime.timedelta(0)
Expand Down

0 comments on commit 036a0dc

Please sign in to comment.