Skip to content

Commit

Permalink
DirectScheduler: use num_cores_per_mpiproc if defined in resources
Browse files Browse the repository at this point in the history
If `num_cores_per_mpiproc` is specified in the job resources, the value
will now be exported as the `OMP_NUM_THREADS` variable.

Co-Authored-By: Sebastiaan Huber <mail@sphuber.net>
  • Loading branch information
dev-zero and sphuber committed Sep 9, 2021
1 parent d33ae8a commit 37a9d23
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 81 deletions.
18 changes: 14 additions & 4 deletions aiida/schedulers/plugins/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,19 +150,29 @@ def _get_submit_script_header(self, job_tmpl):
if job_tmpl.custom_scheduler_commands:
lines.append(job_tmpl.custom_scheduler_commands)

env_lines = []

if job_tmpl.job_resource and job_tmpl.job_resource.num_cores_per_mpiproc:
# since this was introduced after the environment injection below,
# it is intentionally put before it to avoid breaking current users script by overruling
# any explicit OMP_NUM_THREADS they may have set in their job_environment
env_lines.append(f'export OMP_NUM_THREADS={job_tmpl.job_resource.num_cores_per_mpiproc}')

# Job environment variables are to be set on one single line.
# This is a tough job due to the escaping of commas, etc.
# moreover, I am having issues making it work.
# Therefore, I assume that this is bash and export variables by
# and.

if job_tmpl.job_environment:
lines.append(empty_line)
lines.append('# ENVIRONMENT VARIABLES BEGIN ###')
if not isinstance(job_tmpl.job_environment, dict):
raise ValueError('If you provide job_environment, it must be a dictionary')
for key, value in job_tmpl.job_environment.items():
lines.append(f'export {key.strip()}={escape_for_bash(value)}')
env_lines.append(f'export {key.strip()}={escape_for_bash(value)}')

if env_lines:
lines.append(empty_line)
lines.append('# ENVIRONMENT VARIABLES BEGIN ###')
lines += env_lines
lines.append('# ENVIRONMENT VARIABLES END ###')
lines.append(empty_line)

Expand Down
126 changes: 49 additions & 77 deletions tests/schedulers/test_direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,96 +7,68 @@
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
# pylint: disable=invalid-name,protected-access
"""Tests for the `DirectScheduler` plugin."""
import unittest
# pylint: disable=redefined-outer-name
"""Tests for the ``DirectScheduler`` plugin."""
import pytest

from aiida.schedulers.plugins.direct import DirectScheduler
from aiida.common.datastructures import CodeInfo, CodeRunMode
from aiida.schedulers import SchedulerError
from aiida.schedulers.datastructures import JobTemplate
from aiida.schedulers.plugins.direct import DirectScheduler

# This was executed with ps -o pid,stat,user,time | tail -n +2
mac_ps_output_str = """21259 S+ broeder 0:00.04
87619 S+ broeder 0:00.44
87634 S+ broeder 0:00.01
87649 S+ broeder 0:00.02
87664 S+ broeder 0:00.01
87679 S+ broeder 0:00.01
87694 S+ broeder 0:00.01
87711 S+ broeder 0:00.01
87726 S+ broeder 0:00.02
87741 S+ broeder 0:00.01
87756 S+ broeder 0:00.01
87771 S+ broeder 0:00.01
87787 S+ broeder 0:00.02
87803 S+ broeder 0:00.01
87818 S+ broeder 0:00.02
87834 S+ broeder 0:00.02
87849 S+ broeder 0:00.11
87865 S+ broeder 0:00.02
87880 S+ broeder 0:00.02
15967 S+ broeder 0:00.05
87910 S+ broeder 0:00.02
87925 S+ broeder 0:00.02
16814 S broeder 0:00.02
24516 S+ broeder 0:00.06
"""
linux_ps_output_str = """11354 Ss aiida 00:00:00
11383 R+ aiida 00:00:00
11384 S+ aiida 00:00:00
"""

wrong_output = """aaa"""


class TestParserGetJobList(unittest.TestCase):
"""
Tests to verify if teh function _parse_joblist_output behave correctly
The tests is done parsing a string defined above, to be used offline
"""

def test_parse_mac_wrong(self):
"""
Test whether _parse_joblist can parse the qstat -f output
"""
scheduler = DirectScheduler()

with self.assertRaises(SchedulerError):
scheduler._parse_joblist_output(retval=0, stdout=wrong_output, stderr='')

def test_parse_mac_joblist_output(self):
"""
Test whether _parse_joblist can parse the qstat -f output
"""
s = DirectScheduler()
@pytest.fixture
def scheduler():
"""Return an instance of the ``DirectScheduler``."""
return DirectScheduler()

result = s._parse_joblist_output(retval=0, stdout=mac_ps_output_str, stderr='')
self.assertEqual(len(result), 24)

job_ids = [job.job_id for job in result]
self.assertIn('87849', job_ids)
@pytest.fixture
def template():
"""Return an instance of the ``JobTemplate`` with some required presets."""
code_info = CodeInfo()
code_info.cmdline_params = []

def test_parse_linux_joblist_output(self):
"""
Test whether _parse_joblist can parse the qstat -f output
"""
scheduler = DirectScheduler()
template = JobTemplate()
template.codes_info = [code_info]
template.codes_run_mode = CodeRunMode.SERIAL

result = scheduler._parse_joblist_output(retval=0, stdout=linux_ps_output_str, stderr='')
self.assertEqual(len(result), 3)
return template

job_ids = [job.job_id for job in result]
self.assertIn('11383', job_ids)

@pytest.mark.parametrize(
'stdout',
(
"""21259 S+ broeder 0:00.04\n87619 S+ broeder 0:00.44\n87634 S+ broeder 0:00.01""", # MacOS
"""11354 Ss aiida 00:00:00\n\n87619 R+ aiida 00:00:00\n11384 S+ aiida 00:00:00""", # Linux
)
)
def test_parse_joblist_output(scheduler, stdout):
"""Test the ``_parse_joblist_output`` for output taken from MacOS and Linux."""
result = scheduler._parse_joblist_output(retval=0, stdout=stdout, stderr='') # pylint: disable=protected-access
assert len(result) == 3
assert '87619' in [job.job_id for job in result]

def test_submit_script_rerunnable(aiida_caplog):
"""Test that setting the `rerunnable` option gives a warning."""
from aiida.schedulers.datastructures import JobTemplate

direct = DirectScheduler()
job_tmpl = JobTemplate()
def test_parse_joblist_output_incorrect(scheduler):
"""Test the ``_parse_joblist_output`` for invalid output."""
with pytest.raises(SchedulerError):
scheduler._parse_joblist_output(retval=0, stdout='aaa', stderr='') # pylint: disable=protected-access

job_tmpl.rerunnable = True
direct._get_submit_script_header(job_tmpl)

def test_submit_script_rerunnable(scheduler, template, aiida_caplog):
"""Test that setting the ``rerunnable`` option gives a warning."""
template.rerunnable = True
scheduler.get_submit_script(template)
assert 'rerunnable' in aiida_caplog.text
assert 'has no effect' in aiida_caplog.text


def test_submit_script_with_num_cores_per_mpiproc(scheduler, template):
"""Test that passing ``num_cores_per_mpiproc`` in job resources results in ``OMP_NUM_THREADS`` being set."""
num_cores_per_mpiproc = 24
template.job_resource = scheduler.create_job_resource(
num_machines=1, tot_num_mpiprocs=1, num_cores_per_mpiproc=num_cores_per_mpiproc
)
result = scheduler.get_submit_script(template)
assert f'export OMP_NUM_THREADS={num_cores_per_mpiproc}' in result

0 comments on commit 37a9d23

Please sign in to comment.