From 829871473d28bd5ca7bbb6ffe15c35ce64b9b0b1 Mon Sep 17 00:00:00 2001 From: Ivan Orlov Date: Tue, 3 Dec 2024 01:05:29 +0000 Subject: [PATCH] Refactor obtaining "resume_data" and bulk-grouping nodes --- .../slurm_files/scripts/get_tpu_vmcount.py | 2 +- .../modules/slurm_files/scripts/resume.py | 325 ++++++++---------- .../scripts/slurm_gcp_plugins/__init__.py | 9 - .../slurm_gcp_plugins/test_plugin/__init__.py | 13 - .../slurm_files/scripts/tests/common.py | 6 + .../slurm_files/scripts/tests/test_resume.py | 118 +++++++ .../modules/slurm_files/scripts/util.py | 23 +- 7 files changed, 277 insertions(+), 219 deletions(-) create mode 100644 community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/tests/test_resume.py diff --git a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/get_tpu_vmcount.py b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/get_tpu_vmcount.py index 354ec81ad3..1557d6020b 100644 --- a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/get_tpu_vmcount.py +++ b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/get_tpu_vmcount.py @@ -57,7 +57,7 @@ def get_vmcount_of_tpu_part(part): valid = PART_INVALID break else: - if util.part_is_tpu(part): + if util.lookup().partition_is_tpu(part): vmcount = get_vmcount_of_tpu_part(part) if vmcount == -1: valid = DIFF_VMCOUNTS_SAME_PART diff --git a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/resume.py b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/resume.py index 87ec84bd24..b6dc2ac14b 100755 --- a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/resume.py +++ b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/resume.py @@ -15,9 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Dict +from typing import List, Optional, Dict, Collection import argparse -import collections from datetime import timedelta import shlex import json @@ -26,11 +25,11 @@ import yaml from itertools import chain from pathlib import Path +from dataclasses import dataclass import util from util import ( chunked, - dirs, ensure_execute, execute_with_futures, get_insert_operations, @@ -38,20 +37,16 @@ map_with_futures, run, separate, - to_hostlist, to_hostlist_fast, trim_self_link, wait_for_operation, ) -from util import lookup, NSDict, TPU +from util import lookup, NSDict import slurm_gcp_plugins log = logging.getLogger() - -global_resume_data = None - PLACEMENT_MAX_CNT = 150 # Placement group needs to be the same for an entire bulk_insert hence # if placement is used the actual BULK_INSERT_LIMIT will be @@ -59,6 +54,35 @@ BULK_INSERT_LIMIT = 5000 +@dataclass(frozen=True) +class ResumeJobData: + job_id: int + partition: str + nodes_alloc: List[str] + +@dataclass(frozen=True) +class ResumeData: + jobs: List[ResumeJobData] + +def get_resume_file_data() -> Optional[ResumeData]: + if not (path := os.getenv("SLURM_RESUME_FILE")): + log.error("SLURM_RESUME_FILE was not in environment. Cannot get detailed job, node, partition allocation data.") + return None + blob = Path(path).read_text() + log.debug(f"Resume data: {blob}") + data = json.loads(blob) + + jobs = [] + for jo in data.get("jobs", []): + + job = ResumeJobData( + job_id = jo.get("job_id"), + partition = jo.get("partition"), + nodes_alloc = util.to_hostnames(jo.get("nodes_alloc")), + ) + jobs.append(job) + return ResumeData(jobs=jobs) + def instance_properties(nodeset:object, model:str, placement_group:Optional[str], labels:Optional[dict], job_id:Optional[int]): props = NSDict() @@ -202,113 +226,95 @@ def create_instances_request(nodes, partition_name, placement_group, job_id=None return req -def group_nodes_bulk(nodes, resume_data=None): +@dataclass(frozen=True) +class BulkChunk: + nodes: List[str] + prefix: str + chunk_idx: int + job_id: Optional[int] + partition: Optional[str] + placement_group: Optional[str] = None + + +def group_nodes_bulk(nodes: List[str], resume_data: Optional[ResumeData], lkp: util.Lookup): """group nodes by job_id, placement_group, node_group, and max bulkInsert size""" - if resume_data is None: - # all nodes will be considered jobless - jobs = {} - else: - jobs = {job.job_id: job for job in resume_data.jobs} + if resume_data is None: # all nodes will be considered jobless + resume_data = ResumeData(jobs=[]) + + nodes = set(nodes) # turn into set to simplify intersection + + @dataclass(frozen=True) + class JobGroup: # aux struct + job_id: Optional[int] + partition: Optional[str] + placement_groups: Dict[str, List[str]] + + job_groups = {} # expand all job nodelists - for job in jobs.values(): - job.nodelist_alloc = job.nodes_alloc - job.nodes_alloc = util.to_hostnames(job.nodelist_alloc) - job.nodelist_resume = job.nodes_resume - job.nodes_resume = util.to_hostnames(job.nodelist_resume) - job.tpu = util.part_is_tpu(job.partition) - if not job.tpu: + for job in resume_data.jobs: + nodes_resume = nodes & set(job.nodes_alloc) + if lkp.partition_is_tpu(job.partition): # don't create placement groups for TPU + pgs = {None: sorted(nodes_resume)} + else: # create placement groups if nodes for job need it - job.placement_groups = create_placement_groups( - node_list=job.nodes_alloc, - job_id=job.job_id, - ) + pgs = create_placement_groups(job.nodes_alloc, job.job_id) + # placement group assignment is based on all allocated nodes, but we only want to # handle nodes in nodes_resume in this run. - for pg, pg_nodes in job.placement_groups.items(): - job.placement_groups[pg] = list( - set(pg_nodes).intersection(job.nodes_resume) - ) - # a bit of a hack, but nodes resumed using scontrol instead of through job scheduling do not have a job - jobless_nodes = list( - set(nodes).difference( - chain.from_iterable(job.nodes_resume for job in jobs.values()) + for pg, pg_nodes in pgs.items(): + pgs[pg] = sorted(set(pg_nodes) & nodes_resume) + + job_groups[job.job_id] = JobGroup( + job_id=job.job_id, + partition=job.partition, + placement_groups=pgs, ) - ) - jobless_nodes_tpu = [] - for jobless_node in jobless_nodes[:]: - if lookup().node_is_tpu(jobless_node): - jobless_nodes.remove(jobless_node) - jobless_nodes_tpu.append(jobless_node) - jobs["Normal_None"] = NSDict( + all_jobless_nodes = nodes.difference( + chain.from_iterable(j.nodes_alloc for j in resume_data.jobs)) + jobless_nodes, jobless_nodes_tpu = util.separate(lkp.node_is_tpu, all_jobless_nodes) + + job_groups["Normal_None"] = JobGroup( job_id=None, - nodes_resume=jobless_nodes, - nodes_alloc=jobless_nodes, - placement_groups=create_placement_groups(node_list=jobless_nodes), + placement_groups=create_placement_groups(sorted(jobless_nodes), job_id=0), partition=None, - tpu=False, ) - jobs["TPU_None"] = NSDict( + job_groups["TPU_None"] = JobGroup( job_id=None, - nodes_resume=jobless_nodes_tpu, - nodes_alloc=jobless_nodes_tpu, + placement_groups={None: sorted(jobless_nodes_tpu)}, partition=None, - tpu=True, ) - BulkChunk = collections.namedtuple( - "BulkChunk", - ["prefix", "job_id", "partition_name", "placement_group", "nodes", "i"], - ) - BulkChunkTPU = collections.namedtuple( - "BulkChunkTPU", - ["prefix", "job_id", "partition_name", "nodes", "i"], - ) + def chunk_nodes(nodes: List[str]): + chunk_size = BULK_INSERT_LIMIT + if nodes and lkp.node_is_tpu(nodes[0]): + chunk_size = util.TPU(lkp.node_nodeset(nodes[0])).vmcount + return chunked(nodes, n=chunk_size) + grouped_nodes = [ BulkChunk( - prefix, - job_id if job_id != "Normal_None" else None, - jobs[job_id].partition, - placement_group, - chunk_nodes, - i, - ) - for job_id, job in jobs.items() - if not job.tpu + nodes=nodes_chunk, + prefix=prefix, + job_id = job.job_id, + partition = job.partition, + placement_group=placement_group, + chunk_idx=i) + + for job in job_groups.values() for placement_group, pg_nodes in job.placement_groups.items() - for prefix, nodes in util.groupby_unsorted(pg_nodes, lookup().node_prefix) - for i, chunk_nodes in enumerate(chunked(nodes, n=BULK_INSERT_LIMIT)) + for prefix, nodes in util.groupby_unsorted(pg_nodes, lkp.node_prefix) + for i, nodes_chunk in enumerate(chunk_nodes(list(nodes))) ] - grouped_nodes_tpu = [ - BulkChunkTPU( - prefix, - job_id if job_id != "TPU_None" else None, - jobs[job_id].partition, - chunk_nodes, - i, - ) - for job_id, job in jobs.items() - if job.tpu - for prefix, nodes in util.groupby_unsorted(job.nodes_resume, lookup().node_prefix) - for i, chunk_nodes in enumerate(lookup().chunk_tpu_nodes(list(nodes))) - ] - + def group_name(chunk: BulkChunk): if chunk.placement_group is not None: - return f"{chunk.prefix}:job{chunk.job_id}:{chunk.placement_group}:{chunk.i}" - if chunk.job_id is not None: - return f"{chunk.prefix}:job{chunk.job_id}:{chunk.i}" - return f"{chunk.prefix}:{chunk.i}" - - def group_name_tpu(chunk: BulkChunkTPU): + return f"{chunk.prefix}:job{chunk.job_id}:{chunk.placement_group}:{chunk.chunk_idx}" if chunk.job_id is not None: - return f"{chunk.prefix}:job{chunk.job_id}:{chunk.i}" - return f"{chunk.prefix}:{chunk.i}" + return f"{chunk.prefix}:job{chunk.job_id}:{chunk.chunk_idx}" + return f"{chunk.prefix}:{chunk.chunk_idx}" - grouped_nodes = {group_name(chunk): chunk for chunk in grouped_nodes} - grouped_nodes_tpu = {group_name_tpu(chunk): chunk for chunk in grouped_nodes_tpu} - return grouped_nodes, grouped_nodes_tpu + return {group_name(chunk): chunk for chunk in grouped_nodes} def start_tpu(data): @@ -339,55 +345,42 @@ def start_tpu(data): log.error("Error creating tpu node {node}") -def resume_nodes(nodes: List[str], resume_data=None): +def resume_nodes(nodes: List[str], resume_data: Optional[ResumeData]): """resume nodes in nodelist""" if not nodes: log.info("No nodes to resume") return - if resume_data is None and global_resume_data is not None: - resume_data = global_resume_data.deepcopy() - nodes = sorted(nodes, key=lookup().node_prefix) - grouped_nodes, grouped_tpu_nodes = group_nodes_bulk(nodes, resume_data) + grouped_nodes = group_nodes_bulk(nodes, resume_data, lookup()) if log.isEnabledFor(logging.DEBUG): - # grouped_nodelists is used in later debug logs too grouped_nodelists = { - group: to_hostlist(chunk.nodes) for group, chunk in grouped_nodes.items() - } - grouped_tpu_nodelists = { - group: to_hostlist(chunk.nodes) - for group, chunk in grouped_tpu_nodes.items() + group: to_hostlist_fast(chunk.nodes) for group, chunk in grouped_nodes.items() } log.debug( "node bulk groups: \n{}".format(yaml.safe_dump(grouped_nodelists).rstrip()) ) - log.debug( - "TPU node bulk groups: \n{}".format( - yaml.safe_dump(grouped_tpu_nodelists).rstrip() - ) - ) + tpu_start_data = [] tpu_objs = {} - for group, chunk in grouped_tpu_nodes.items(): - # do not create multiple tpu_objs if nodes with the same prefix are used - if chunk.prefix not in tpu_objs.keys(): - model = chunk.nodes[0] - tpu_objs[chunk.prefix] = TPU(lookup().node_nodeset(model)) - - tpu_start_data.append({"tpu": tpu_objs[chunk.prefix], "node": chunk.nodes}) - - # make all bulkInsert requests and execute with batch - inserts = { - group: create_instances_request( - chunk.nodes, chunk.partition_name, chunk.placement_group, chunk.job_id - ) - for group, chunk in grouped_nodes.items() - } + bi_inserts = {} + + for group, chunk in grouped_nodes.items(): + if chunk.partition and lookup().partition_is_tpu(chunk.partition): + # do not create multiple tpu_objs if nodes with the same prefix are used + if chunk.prefix not in tpu_objs.keys(): + model = chunk.nodes[0] + tpu_objs[chunk.prefix] = util.TPU(lookup().node_nodeset(model)) + tpu_start_data.append({"tpu": tpu_objs[chunk.prefix], "node": chunk.nodes}) + else: + bi_inserts[group] = create_instances_request( + chunk.nodes, chunk.partition, chunk.placement_group, chunk.job_id + ) + # execute all bulkInsert requests with batch bulk_ops = dict( - zip(inserts.keys(), map_with_futures(ensure_execute, inserts.values())) + zip(bi_inserts.keys(), map_with_futures(ensure_execute, bi_inserts.values())) ) log.debug(f"bulk_ops={yaml.safe_dump(bulk_ops)}") started = { @@ -400,7 +393,7 @@ def resume_nodes(nodes: List[str], resume_data=None): failed_reqs = [str(e) for e in failed.items()] log.error("bulkInsert API failures: {}".format("; ".join(failed_reqs))) for ident, exc in failed.items(): - down_nodes(grouped_nodes[ident].nodes, f"GCP Error: {exc._get_reason()}") + down_nodes_notify_jobs(grouped_nodes[ident].nodes, f"GCP Error: {exc._get_reason()}", resume_data) if log.isEnabledFor(logging.DEBUG): for group, op in started.items(): @@ -449,7 +442,7 @@ def resume_nodes(nodes: List[str], resume_data=None): for err in failed_op["error"]["errors"] ) if code != "RESOURCE_ALREADY_EXISTS": - down_nodes(hostlist, f"GCP Error: {msg}") + down_nodes_notify_jobs(failed_nodes, f"GCP Error: {msg}", resume_data) log.error( f"errors from insert for node '{failed_node}' ({failed_op['name']}): {msg}" ) @@ -461,33 +454,25 @@ def resume_nodes(nodes: List[str], resume_data=None): all_successful_inserts.extend(successful_inserts) -def update_job_comment(nodelist: str, comment: str): - if global_resume_data is None: - log.warning( - "Cannot update and notify jobs with API failures as no valid resume file is present." - ) - return - - nodes = util.to_hostnames(nodelist) - job_list = ( - job - for job in global_resume_data.jobs - if any(map(lambda node: node in nodes, util.to_hostnames(job.nodelist_resume))) - ) - for job in job_list: - run(f"{lookup().scontrol} update jobid={job.job_id} admincomment='{comment}'") - run(f"{lookup().scontrol} notify {job.job_id} '{comment}'") - - -def down_nodes(nodelist, reason): +def down_nodes_notify_jobs(nodes: List[str], reason: str, resume_data: Optional[ResumeData]) -> None: """set nodes down with reason""" - if isinstance(nodelist, list): - nodelist = util.to_hostlist(nodelist) - update_job_comment(nodelist, reason) + nodelist = util.to_hostlist_fast(nodes) reason_quoted = shlex.quote(reason) + log.error(f"Marking nodes {nodelist} as DOWN, reason: {reason}") run(f"{lookup().scontrol} update nodename={nodelist} state=down reason={reason_quoted}") + if resume_data is None: + log.warning("Cannot update and notify jobs with API failures as no valid resume file is present.") + return + + nodes = set(nodes) # turn into set to speed up intersection + for job in resume_data.jobs: + if not (set(job.nodes_alloc) & nodes): + continue + run(f"{lookup().scontrol} update jobid={job.job_id} admincomment='{reason_quoted}'") + run(f"{lookup().scontrol} notify {job.job_id} '{reason_quoted}'") + def hold_job(job_id, reason): """hold job, set comment to reason""" @@ -514,7 +499,7 @@ def create_placement_request(pg_name, region): return request -def create_placement_groups(node_list: List[str], job_id:int=0) -> Dict[str, List[str]]: +def create_placement_groups(node_list: List[str], job_id:int) -> Dict[str, List[str]]: pgs = {} node_map = lookup().nodeset_map(node_list) for _, nodes in node_map.items(): @@ -603,52 +588,28 @@ def valid_placement_nodes(nodelist): return True -def get_resume_file_data(): - SLURM_RESUME_FILE = os.getenv("SLURM_RESUME_FILE") - if SLURM_RESUME_FILE is None: - log.warning( - "SLURM_RESUME_FILE was not in environment. Cannot get detailed job, node, partition allocation data." - ) - return None - resume_file = Path(SLURM_RESUME_FILE) - resume_json = resume_file.read_text() - if log.isEnabledFor(logging.DEBUG): - (dirs.scripts / "resume_data.json").write_text(resume_json) - return NSDict(json.loads(resume_json)) - - -def main(nodelist): +def main(nodelist: str) -> None: """main called when run as script""" log.debug(f"ResumeProgram {nodelist}") # Filter out nodes not in config.yaml - other_nodes, pm_nodes = separate( + other_nodes, nodes = separate( lookup().is_power_managed_node, util.to_hostnames(nodelist) ) if other_nodes: - log.debug( + log.error( f"Ignoring non-power-managed nodes '{to_hostlist_fast(other_nodes)}' from '{nodelist}'" ) - pm_nodelist = util.to_hostlist_fast(pm_nodes) - if pm_nodes: - log.debug(f"Resuming nodes '{pm_nodelist}' from '{nodelist}'") - else: - log.debug("No nodes to resume") + if not nodes: + log.info("No nodes to resume") return - log.info(f"resume {pm_nodelist}") - resume_nodes(pm_nodes, global_resume_data) - # TODO only run below if resume_nodes succeeds but - # resume_nodes does not currently return any status. - if lookup().cfg.enable_slurm_gcp_plugins: - slurm_gcp_plugins.post_main_resume_nodes( - lkp=lookup(), nodelist=nodelist, global_resume_data=global_resume_data - ) - + resume_data = get_resume_file_data() + log.info(f"resume {util.to_hostlist_fast(nodes)}") + resume_nodes(nodes, resume_data) + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("nodelist", help="list of nodes to resume") args = util.init_log_and_parse(parser) - - global_resume_data = get_resume_file_data() main(args.nodelist) diff --git a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/slurm_gcp_plugins/__init__.py b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/slurm_gcp_plugins/__init__.py index c56793c4be..dec7085994 100644 --- a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/slurm_gcp_plugins/__init__.py +++ b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/slurm_gcp_plugins/__init__.py @@ -100,14 +100,6 @@ def register_instance_information_fields(*pos_args, **keyword_args): ) -# Called just after VM instances have been created and are up -def post_main_resume_nodes(*pos_args, **keyword_args): - run_plugins_for_function( - plugin_function_name="post_main_resume_nodes", - pos_args=pos_args, - keyword_args=keyword_args, - ) - # Called just before VM instances are deleted should be still up # (NOTE: if a node has failed it might not be up or unresponsive) @@ -141,7 +133,6 @@ def pre_placement_group_insert(*pos_args, **keyword_args): __all__ = [ - "post_main_resume_nodes", "pre_main_suspend_nodes", "register_instance_information_fields", "pre_instance_bulk_insert", diff --git a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/slurm_gcp_plugins/test_plugin/__init__.py b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/slurm_gcp_plugins/test_plugin/__init__.py index 67dbd5d408..b4b3be580d 100644 --- a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/slurm_gcp_plugins/test_plugin/__init__.py +++ b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/slurm_gcp_plugins/test_plugin/__init__.py @@ -22,20 +22,7 @@ def register_instance_information_fields(*pos_args, **keyword_args): keyword_args["instance_information_fields"].extend(instance_information_fields) -def post_main_resume_nodes(*pos_args, **keyword_args): - logging.debug("post_main_resume_nodes called from test_plugin") - for node in keyword_args["nodelist"]: - logging.info( - ( - "test_plugin:" - + f"nodename:{node} " - + f"instance_id:{keyword_args['lkp'].instance(node)['id']} " - + f"physicalHost:{keyword_args['lkp'].instance(node)['resourceStatus']['physicalHost']}" - ) - ) - __all__ = [ "register_instance_information_fields", - "post_main_resume_nodes", ] diff --git a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/tests/common.py b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/tests/common.py index 2272aeef99..bfe7f5cc9c 100644 --- a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/tests/common.py +++ b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/tests/common.py @@ -37,6 +37,12 @@ class TstNodeset: reservation_name: Optional[str] = "" zone_policy_allow: Optional[list[str]] = field(default_factory=list) +@dataclass +class TstPartition: + partition_name: str = "euler" + partition_nodeset: list[str] = field(default_factory=list) + partition_nodeset_tpu: list[str] = field(default_factory=list) + @dataclass class TstCfg: slurm_cluster_name: str = "m22" diff --git a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/tests/test_resume.py b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/tests/test_resume.py new file mode 100644 index 0000000000..7d9dfe4ac1 --- /dev/null +++ b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/tests/test_resume.py @@ -0,0 +1,118 @@ +# Copyright 2024 "Google LLC" +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest.mock +import unittest +import tempfile + +from common import TstCfg, TstNodeset, TstPartition, TstTPU # needed to import util +import util +from resume import get_resume_file_data, ResumeData, ResumeJobData, group_nodes_bulk, BulkChunk + +def test_get_resume_file_data_no_env(): + with unittest.mock.patch.dict(os.environ, {"SLURM_RESUME_FILE": ""}): + assert get_resume_file_data() is None + + +def test_get_resume_file_data(): + with tempfile.NamedTemporaryFile() as f: + f.write(b"""{ + "jobs": [ + { + "extra": null, + "job_id": 1, + "features": null, + "nodes_alloc": "green-[0-2]", + "nodes_resume": "green-[0-1]", + "oversubscribe": "OK", + "partition": "red", + "reservation": null + } + ], + "all_nodes_resume": "green-[0-1]" +}""") + f.flush() + with ( + unittest.mock.patch.dict(os.environ, {"SLURM_RESUME_FILE": f.name}), + unittest.mock.patch("util.to_hostnames") as mock_to_hostnames, + ): + mock_to_hostnames.return_value = ["green-0", "green-1", "green-2"] + assert get_resume_file_data() == ResumeData(jobs=[ + ResumeJobData( + job_id = 1, + partition="red", + nodes_alloc=["green-0", "green-1", "green-2"], + ) + ]) + mock_to_hostnames.assert_called_once_with("green-[0-2]") + + +@unittest.mock.patch("util.TPU") +@unittest.mock.patch("resume.create_placement_groups") +def test_group_nodes_bulk(mock_create_placement_groups, mock_tpu): + cfg = TstCfg( + nodeset={ + "n": TstNodeset(nodeset_name="n"), + }, + nodeset_tpu={ + "t": TstNodeset(nodeset_name="t"), + }, + partitions={ + "p1": TstPartition(partition_name="p1"), + "p2": TstPartition( + partition_name="p2", + partition_nodeset_tpu=["t"], + ) + } + ) + lkp = util.Lookup(cfg) + + def mock_create_placement_groups_se(nodes, job_id): + args = (set(nodes), job_id) + if ({"c-n-1", "c-n-2"}, 0) == args: + return { "g0": ["c-n-1", "c-n-2"] } + if ({"c-n-0", "c-n-8"}, 1) == args: + return { + "g10": ["c-n-0"], + "g11": ["c-n-8"], + } + raise AssertionError(f"unexpected invocation: '{args}'") + mock_create_placement_groups.side_effect = mock_create_placement_groups_se + + def mock_tpu_se(ns: TstNodeset) -> TstTPU: + if ns.nodeset_name == "t": + return TstTPU(vmcount=2) + raise AssertionError(f"unexpected invocation: '{ns}'") + mock_tpu.side_effect = mock_tpu_se + + got = group_nodes_bulk( + ["c-n-0", "c-n-1", "c-n-2", "c-t-0", "c-t-1", "c-t-2", "c-t-3", "c-t-8", "c-t-9"], + ResumeData(jobs=[ + ResumeJobData(job_id=1, partition="p1", nodes_alloc=["c-n-0", "c-n-8"]), + ResumeJobData(job_id=2, partition="p2", nodes_alloc=["c-t-0", "c-t-1", "c-t-2", "c-t-3", "c-t-4", "c-t-5"]), + ]), lkp) + mock_create_placement_groups.assert_called() + assert got == { + "c-n:jobNone:g0:0": BulkChunk( + nodes=["c-n-1", "c-n-2"], prefix="c-n", chunk_idx=0, job_id=None, partition=None, placement_group="g0"), + "c-n:job1:g10:0": BulkChunk( + nodes=["c-n-0"], prefix="c-n", chunk_idx=0, job_id=1, partition="p1", placement_group="g10"), + "c-t:0": BulkChunk( + nodes=["c-t-8", "c-t-9"], prefix="c-t", chunk_idx=0, job_id=None, partition=None, placement_group=None), + "c-t:job2:0": BulkChunk( + nodes=["c-t-0", "c-t-1"], prefix="c-t", chunk_idx=0, job_id=2, partition="p2", placement_group=None), + "c-t:job2:1": BulkChunk( + nodes=["c-t-2", "c-t-3"], prefix="c-t", chunk_idx=1, job_id=2, partition="p2", placement_group=None), + } diff --git a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/util.py b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/util.py index 1d07678619..bbe69f24dd 100755 --- a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/util.py +++ b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/util.py @@ -947,11 +947,7 @@ def cur_repr(): res.append(f"{p}[{','.join(cs)}]") return ",".join(res) - -def part_is_tpu(part): - """check if partition with name part contains a nodeset of type tpu""" - return len(lookup().cfg.partitions[part].partition_nodeset_tpu) > 0 - +@lru_cache(maxsize=None) def to_hostnames(nodelist: str) -> List[str]: """make list of hostnames from hostlist expression""" if not nodelist: @@ -1570,10 +1566,14 @@ def node_nodeset_name(self, node_name=None): def node_nodeset(self, node_name=None): nodeset_name = self.node_nodeset_name(node_name) - ns = self.cfg.nodeset.get(nodeset_name) - if ns: - return ns - return self.cfg.nodeset_tpu.get(nodeset_name) + if nodeset_name in self.cfg.nodeset_tpu: + return self.cfg.nodeset_tpu[nodeset_name] + return self.cfg.nodeset[nodeset_name] + + def partition_is_tpu(self, part: str) -> bool: + """check if partition with name part contains a nodeset of type tpu""" + return len(self.cfg.partitions[part].partition_nodeset_tpu) > 0 + def node_is_tpu(self, node_name=None): nodeset_name = self.node_nodeset_name(node_name) @@ -1583,11 +1583,6 @@ def node_is_dyn(self, node_name=None) -> bool: nodeset = self.node_nodeset_name(node_name) return self.cfg.nodeset_dyn.get(nodeset) is not None - def chunk_tpu_nodes(self, tpu_nodes): - model = tpu_nodes[0] - tpu = TPU(self.node_nodeset(model)) - return chunked(tpu_nodes, n=tpu.vmcount) - def node_template(self, node_name=None): return self.node_nodeset(node_name).instance_template