Skip to content

Commit

Permalink
Merge pull request #3354 from mr0re1/chunk_pp
Browse files Browse the repository at this point in the history
SlurmGCP. Improve non-exclusive placement allocation
  • Loading branch information
mr0re1 authored Dec 7, 2024
2 parents 8b4d994 + f84c4f7 commit c519141
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Dict, Collection
from typing import List, Optional
import argparse
from datetime import timedelta
import shlex
import json
import logging
import os
import yaml
from itertools import chain
import collections
from pathlib import Path
from dataclasses import dataclass

Expand All @@ -47,7 +47,7 @@

log = logging.getLogger()

PLACEMENT_MAX_CNT = 150
PLACEMENT_MAX_CNT = 1500
# Placement group needs to be the same for an entire bulk_insert hence
# if placement is used the actual BULK_INSERT_LIMIT will be
# max([1000, PLACEMENT_MAX_CNT])
Expand Down Expand Up @@ -262,7 +262,7 @@ def group_nodes_bulk(nodes: List[str], resume_data: Optional[ResumeData], lkp: u

groups[job.job_id] = []
# placement group assignment is based on all allocated nodes, ...
for pn in create_placement_groups(job.nodes_alloc, job.job_id, lkp):
for pn in create_placements(job.nodes_alloc, job.job_id, lkp):
groups[job.job_id].append(
PlacementAndNodes(
placement=pn.placement,
Expand All @@ -271,7 +271,7 @@ def group_nodes_bulk(nodes: List[str], resume_data: Optional[ResumeData], lkp: u
))
non_excl.difference_update(job.nodes_alloc)

groups[None] = create_placement_groups(sorted(non_excl), job_id=0, lkp=lkp)
groups[None] = create_placements(sorted(non_excl), excl_job_id=None, lkp=lkp)

def chunk_nodes(nodes: List[str]):
chunk_size = BULK_INSERT_LIMIT
Expand Down Expand Up @@ -483,45 +483,79 @@ def create_placement_request(pg_name, region):
return request


def create_placement_groups(nodes: List[str], job_id:int, lkp: util.Lookup) -> List[PlacementAndNodes]:
res = []
for _, ns_nodes in lkp.nodeset_map(nodes).items():
res.extend(create_nodeset_placement_groups(ns_nodes, job_id, lkp))
return res
def create_placements(nodes: List[str], excl_job_id:Optional[int], lkp: util.Lookup) -> List[PlacementAndNodes]:
    """Partition `nodes` by their nodeset and gather placements for each group.

    Delegates the per-nodeset work to `create_nodeset_placements`; the
    `excl_job_id` is forwarded unchanged (None for non-exclusive nodes).
    """
    by_nodeset = {}  # nodeset name -> nodes belonging to it
    for node in nodes:
        by_nodeset.setdefault(lkp.node_nodeset_name(node), []).append(node)

    result = []
    for ns_nodes in by_nodeset.values():
        result += create_nodeset_placements(ns_nodes, excl_job_id, lkp)
    return result

def create_nodeset_placement_groups(nodes: List[str], job_id:int, lkp: util.Lookup) -> List[PlacementAndNodes]:

def _allocate_nodes_to_placements(nodes: List[str], excl_job_id:Optional[int], lkp: util.Lookup) -> List[PlacementAndNodes]:
# canned result for no placement policies created
no_pp = [PlacementAndNodes(placement=None, nodes=nodes)]

if len(nodes) < 2:
if excl_job_id and len(nodes) < 2:
return no_pp # don't create placement_policy for just one node

model = nodes[0]
nodeset = lkp.node_nodeset(model)
if not (nodeset.enable_placement and valid_placement_node(model)):
return no_pp

if lkp.node_is_tpu(model):
return no_pp
if not (nodeset.enable_placement and valid_placement_node(model)):
return no_pp

region = lkp.node_region(model)

groups = [
PlacementAndNodes(
placement=f"{lkp.cfg.slurm_cluster_name}-slurmgcp-managed-{nodeset.nodeset_name}-{job_id}-{i}",
nodes=chunk
)
for i, chunk in enumerate(chunked(nodes, n=PLACEMENT_MAX_CNT))
name_prefix = f"{lkp.cfg.slurm_cluster_name}-slurmgcp-managed-{nodeset.nodeset_name}"
if excl_job_id: # simply chunk given nodes by max size of placement
return [
PlacementAndNodes(placement=f"{name_prefix}-{excl_job_id}-{i}", nodes=chunk)
for i, chunk in enumerate(chunked(nodes, n=PLACEMENT_MAX_CNT))
]

# split whole nodeset (not only nodes to resume) into chunks of max size of placement
# create placements (most likely they already exist) for the requested nodes
chunks = collections.defaultdict(list) # chunk_id -> nodes
invalid = []

for node in nodes:
try:
chunk = lkp.node_index(node) // PLACEMENT_MAX_CNT
chunks[chunk].append(node)
except:
invalid.append(node)

placements = [
# NOTE: use 0 instead of job_id for consistency with previous SlurmGCP behavior
PlacementAndNodes(placement=f"{name_prefix}-0-{c_id}", nodes=c_nodes)
for c_id, c_nodes in chunks.items()
]

if invalid:
placements.append(PlacementAndNodes(placement=None, nodes=invalid))
log.error(f"Could not find placement for nodes with unexpected names: {to_hostlist_fast(invalid)}")

return placements

def create_nodeset_placements(nodes: List[str], excl_job_id:Optional[int], lkp: util.Lookup) -> List[PlacementAndNodes]:
placements = _allocate_nodes_to_placements(nodes, excl_job_id, lkp)
region = lkp.node_region(nodes[0])

if log.isEnabledFor(logging.DEBUG):
debug_groups = {g.placement: to_hostlist_fast(g.nodes) for g in groups}
debug_p = {p.placement: to_hostlist_fast(p.nodes) for p in placements}
log.debug(
f"creating {len(groups)} placement groups: \n{yaml.safe_dump(debug_groups).rstrip()}"
f"creating {len(placements)} placement groups: \n{yaml.safe_dump(debug_p).rstrip()}"
)

requests = {
g.placement: create_placement_request(g.placement, region) for g in groups
p.placement: create_placement_request(p.placement, region) for p in placements if p.placement
}
if not requests:
return placements
# TODO: aggregate all requests for whole resume and execute them at once (don't limit to nodeset/job)
ops = dict(
zip(requests.keys(), map_with_futures(ensure_execute, requests.values()))
)
Expand Down Expand Up @@ -559,7 +593,7 @@ def classify_result(item):
log.info(
f"created {len(operations)} placement groups ({to_hostlist_fast(operations.keys())})"
)
return groups
return placements


def valid_placement_node(node: str) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class TstNodeset:
instance_template: Optional[str] = None
reservation_name: Optional[str] = ""
zone_policy_allow: Optional[list[str]] = field(default_factory=list)
enable_placement: bool = True

@dataclass
class TstPartition:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import os
import pytest
import unittest.mock
import unittest
import tempfile

from common import TstCfg, TstNodeset, TstPartition, TstTPU # needed to import util
import util
from resume import get_resume_file_data, ResumeData, ResumeJobData, group_nodes_bulk, BulkChunk, PlacementAndNodes
import resume
from resume import ResumeData, ResumeJobData, BulkChunk, PlacementAndNodes

def test_get_resume_file_data_no_env():
with unittest.mock.patch.dict(os.environ, {"SLURM_RESUME_FILE": ""}):
assert get_resume_file_data() is None
assert resume.get_resume_file_data() is None


def test_get_resume_file_data():
Expand All @@ -49,7 +53,7 @@ def test_get_resume_file_data():
unittest.mock.patch("util.to_hostnames") as mock_to_hostnames,
):
mock_to_hostnames.return_value = ["green-0", "green-1", "green-2"]
assert get_resume_file_data() == ResumeData(jobs=[
assert resume.get_resume_file_data() == ResumeData(jobs=[
ResumeJobData(
job_id = 1,
partition="red",
Expand All @@ -60,8 +64,8 @@ def test_get_resume_file_data():


@unittest.mock.patch("util.TPU")
@unittest.mock.patch("resume.create_placement_groups")
def test_group_nodes_bulk(mock_create_placement_groups, mock_tpu):
@unittest.mock.patch("resume.create_placements")
def test_group_nodes_bulk(mock_create_placements, mock_tpu):
cfg = TstCfg(
nodeset={
"n": TstNodeset(nodeset_name="n"),
Expand All @@ -83,9 +87,9 @@ def test_group_nodes_bulk(mock_create_placement_groups, mock_tpu):
)
lkp = util.Lookup(cfg)

def mock_create_placement_groups_se(nodes, job_id, lkp):
args = (set(nodes), job_id)
if ({'c-n-1', 'c-n-2', 'c-t-8', 'c-t-9'}, 0) == args:
def mock_create_placements_se(nodes, excl_job_id, lkp):
args = (set(nodes), excl_job_id)
if ({'c-n-1', 'c-n-2', 'c-t-8', 'c-t-9'}, None) == args:
return [
PlacementAndNodes("g0", ["c-n-1", "c-n-2"]),
PlacementAndNodes(None, ['c-t-8', 'c-t-9']),
Expand All @@ -100,21 +104,21 @@ def mock_create_placement_groups_se(nodes, job_id, lkp):
PlacementAndNodes(None, ['c-t-0', 'c-t-1', 'c-t-2', 'c-t-3', 'c-t-4', 'c-t-5'])
]
raise AssertionError(f"unexpected invocation: '{args}'")
mock_create_placement_groups.side_effect = mock_create_placement_groups_se
mock_create_placements.side_effect = mock_create_placements_se

def mock_tpu_se(ns: TstNodeset) -> TstTPU:
if ns.nodeset_name == "t":
return TstTPU(vmcount=2)
raise AssertionError(f"unexpected invocation: '{ns}'")
mock_tpu.side_effect = mock_tpu_se

got = group_nodes_bulk(
got = resume.group_nodes_bulk(
["c-n-0", "c-n-1", "c-n-2", "c-t-0", "c-t-1", "c-t-2", "c-t-3", "c-t-8", "c-t-9"],
ResumeData(jobs=[
ResumeJobData(job_id=1, partition="p1", nodes_alloc=["c-n-0", "c-n-8"]),
ResumeJobData(job_id=2, partition="p2", nodes_alloc=["c-t-0", "c-t-1", "c-t-2", "c-t-3", "c-t-4", "c-t-5"]),
]), lkp)
mock_create_placement_groups.assert_called()
mock_create_placements.assert_called()
assert got == {
"c-n:jobNone:g0:0": BulkChunk(
nodes=["c-n-1", "c-n-2"], prefix="c-n", chunk_idx=0, excl_job_id=None, placement_group="g0"),
Expand All @@ -127,3 +131,43 @@ def mock_tpu_se(ns: TstNodeset) -> TstTPU:
"c-t:job2:1": BulkChunk(
nodes=["c-t-2", "c-t-3"], prefix="c-t", chunk_idx=1, excl_job_id=2, placement_group=None),
}


@pytest.mark.parametrize(
    "nodes,excl_job_id,expected",
    [
        ( # TPU - no placements
            ["c-t-0", "c-t-2"], 4, [PlacementAndNodes(None, ["c-t-0", "c-t-2"])]
        ),
        ( # disabled placements - no placements
            ["c-x-0", "c-x-2"], 4, [PlacementAndNodes(None, ["c-x-0", "c-x-2"])]
        ),
        ( # excl_job
            ["c-n-0", "c-n-uno", "c-n-2", "c-n-2011"], 4, [
                PlacementAndNodes("c-slurmgcp-managed-n-4-0", ["c-n-0", "c-n-uno", "c-n-2", "c-n-2011"])
            ]
        ),
        ( # no excl_job
            ["c-n-0", "c-n-uno", "c-n-2", "c-n-2011"], None, [
                PlacementAndNodes("c-slurmgcp-managed-n-0-0", ["c-n-0", "c-n-2"]),
                PlacementAndNodes('c-slurmgcp-managed-n-0-1', ['c-n-2011']),
                PlacementAndNodes(None, ["c-n-uno"]),
            ]
        ),
    ],
)
def test_allocate_nodes_to_placements(nodes: list[str], excl_job_id: Optional[int], expected: list[PlacementAndNodes]):
    """Verify node->placement allocation for TPU, disabled, exclusive and non-exclusive cases."""
    lkp = util.Lookup(TstCfg(
        slurm_cluster_name="c",
        nodeset={
            "n": TstNodeset(nodeset_name="n", enable_placement=True),
            "x": TstNodeset(nodeset_name="x", enable_placement=False)
        },
        nodeset_tpu={
            "t": TstNodeset(nodeset_name="t")
        }))

    # force every node to be considered placement-capable
    with unittest.mock.patch("resume.valid_placement_node", return_value=True):
        got = resume._allocate_nodes_to_placements(nodes, excl_job_id, lkp)
    assert got == expected
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,24 @@ def test_node_desc(name, expected):
assert util.lookup()._node_desc(name) == expected


@pytest.mark.parametrize(
    "name,expected",
    [
        ("az-buka-23", 23),
        ("az-buka-0", 0),
        ("az-buka", Exception),
        ("az-buka-xyzf", ValueError),
        ("az-buka-[2-3]", ValueError),
    ],
)
def test_node_index(name, expected):
    """node_index returns the trailing numeric suffix, or raises for malformed names."""
    expect_raise = isinstance(expected, type) and issubclass(expected, Exception)
    if expect_raise:
        with pytest.raises(expected):
            util.lookup().node_index(name)
    else:
        assert util.lookup().node_index(name) == expected


@pytest.mark.parametrize(
"name",
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1579,6 +1579,14 @@ def _node_desc(self, node_name):

def node_prefix(self, node_name=None):
return self._node_desc(node_name)["prefix"]

def node_index(self, node: str) -> int:
    """Return the numeric suffix of a node name: node_index("cluster-nodeset-45") == 45.

    Raises ValueError when the name carries no numeric suffix.
    """
    suffix = self._node_desc(node)["suffix"]
    if suffix is None:
        raise ValueError(f"Node {node} name does not end with numeric index")
    return int(suffix)

def node_nodeset_name(self, node_name=None):
return self._node_desc(node_name)["nodeset"]
Expand Down Expand Up @@ -1991,12 +1999,6 @@ def template_info(self, template_link):

return template

def nodeset_map(self, hostnames: list):
"""Convert a list of nodes into a map of nodeset_name to hostnames"""
nodeset_map = collections.defaultdict(list)
for node in hostnames:
nodeset_map[self.node_nodeset_name(node)].append(node)
return nodeset_map

def _parse_job_info(self, job_info: str) -> Job:
"""Extract job details"""
Expand Down

0 comments on commit c519141

Please sign in to comment.