Skip to content

Commit

Permalink
Refactor SlurmSystem methods for bug fix and add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
TaekyungHeo committed May 16, 2024
1 parent 67762e4 commit 0779737
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 59 deletions.
126 changes: 82 additions & 44 deletions src/cloudai/schema/system/slurm/slurm_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import getpass
import logging
import re
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

from cloudai.schema.core import System
from cloudai.util import CommandShell
Expand Down Expand Up @@ -493,61 +493,99 @@ def scancel(self, job_id: int) -> None:
"""
self.cmd_shell.execute(f"scancel {job_id}")

def update_node_states(self) -> None: # noqa: C901
def update_node_states(self, squeue_output: Optional[str] = None, sinfo_output: Optional[str] = None) -> None:
"""
Updates the states of nodes in the Slurm system by querying the current
state of each node using the 'sinfo' command, and correlates this with
'squeue' to determine which user is running jobs on each node. The method
parses the output of these commands, identifies the state of nodes and the
Updates the states of nodes in the Slurm system by querying the current state of each node using
the 'sinfo' command, and correlates this with 'squeue' to determine which user is running jobs on
each node. This method parses the output of these commands, identifies the state of nodes and the
users, and updates the corresponding SlurmNode instances in the system.
This method does not return any value. It updates the internal state of
SlurmNode instances based on the current state and user information
reported by 'sinfo' and 'squeue'.
Args:
squeue_output (str): The output from the squeue command, if already fetched.
sinfo_output (str): The output from the sinfo command, if already fetched.
"""
node_user_map = {}
squeue_command = "squeue -o '%N|%u' --noheader"
self.logger.debug(f"Updating node user information with command: {squeue_command}")
squeue_stdout, squeue_stderr = self.cmd_shell.execute(squeue_command).communicate()
if squeue_stderr:
self.logger.error(f"Error querying node user information: {squeue_stderr}")
return

for line in squeue_stdout.split("\n"):
if line.strip():
node_list, user = line.split("|")
for node in self.parse_node_list([node_list]):
node_user_map[node] = user
if squeue_output is None or sinfo_output is None:
squeue_output, _ = self.fetch_command_output("squeue -o '%N|%u' --noheader")
sinfo_output, _ = self.fetch_command_output("sinfo")

node_user_map = self.parse_squeue_output(squeue_output)
self.parse_sinfo_output(sinfo_output, node_user_map)

def fetch_command_output(self, command: str) -> Tuple[str, str]:
"""
Execute a system command and return its output.
Args:
command (str): The command to execute.
command = "sinfo"
self.logger.debug(f"Updating node states with command: {command}")
Returns:
Tuple[str, str]: The stdout and stderr from the command execution.
"""
self.logger.debug(f"Executing command: {command}")
stdout, stderr = self.cmd_shell.execute(command).communicate()
if stderr:
self.logger.error(f"Error querying node states: {stderr}")
return
self.logger.error(f"Error executing command '{command}': {stderr}")
return stdout, stderr

def parse_squeue_output(self, squeue_output: str) -> Dict[str, str]:
"""
Parse the output from the 'squeue' command to map nodes to users.
The expected format of squeue_output is lines of 'node_spec|user', where
node_spec can include comma-separated node names or ranges.
Args:
squeue_output (str): The raw output from the squeue command.
Returns:
Dict[str, str]: A dictionary mapping node names to usernames.
"""
node_user_map = {}
for line in squeue_output.split("\n"):
if line.strip():
# Split the line into node list and user, handling only the first '|'
parts = line.split("|")
if len(parts) < 2:
continue # Skip malformed lines

node_list_part, user = parts[0], "|".join(parts[1:])
# Handle cases where multiple node groups or ranges are specified
node_groups = node_list_part.split(",")
for node_group in node_groups:
# Process each node or range using parse_node_list
for node in self.parse_node_list([node_group.strip()]):
node_user_map[node] = user.strip()

# Parsing the output of 'sinfo' to update node states
for line in stdout.split("\n")[1:]: # Skip the header line
return node_user_map

def parse_sinfo_output(self, sinfo_output: str, node_user_map: Dict[str, str]) -> None:
"""
Parse the output from the 'sinfo' command to update node states.
Args:
sinfo_output (str): The output from the sinfo command.
node_user_map (dict): A dictionary mapping node names to users.
"""
for line in sinfo_output.split("\n")[1:]: # Skip the header line
if not line.strip():
continue # Skip empty lines
continue
parts = line.split()
partition, _, _, _, state, nodelist = parts[:6]
partition = partition.rstrip("*")
node_names = self.parse_node_list([nodelist])

# Convert state to enum, handling states with suffixes
state_enum = self.convert_state_to_enum(state)

for node_name in node_names:
# Find the partition and node to update the state
for part_name, nodes in self.partitions.items():
if part_name != partition:
continue
for node in nodes:
if node.name == node_name:
node.state = state_enum
node.user = node_user_map.get(node_name, "N/A")
break

node_groups = nodelist.split(",")
for node_group in node_groups:
node_names = self.parse_node_list([node_group.strip()])
state_enum = self.convert_state_to_enum(state)

for node_name in node_names:
for part_name, nodes in self.partitions.items():
if part_name != partition:
continue
for node in nodes:
if node.name == node_name:
node.state = state_enum
node.user = node_user_map.get(node_name, "N/A")

def convert_state_to_enum(self, state_str: str) -> SlurmNodeState:
"""
Expand Down
64 changes: 49 additions & 15 deletions tests/schema/system/slurm/test_slurm_system.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,56 @@
import pytest
from cloudai.schema.system import SlurmSystem
from cloudai.schema.system.slurm import SlurmNode, SlurmNodeState


def test_parse_node_list_single() -> None:
"""Test parsing a list with single node names."""
node_list = ["node1", "node2"]
expected = ["node1", "node2"]
assert SlurmSystem.parse_node_list(node_list) == expected
@pytest.fixture
def slurm_system():
nodes = [
SlurmNode(name="nodeA001", partition="main", state=SlurmNodeState.UNKNOWN_STATE),
SlurmNode(name="nodeB001", partition="main", state=SlurmNodeState.UNKNOWN_STATE),
]
system = SlurmSystem(
name="test_system",
install_path="/fake/path",
output_path="/fake/output",
default_partition="main",
partitions={"main": nodes},
)
return system


def test_parse_node_list_range() -> None:
"""Test parsing a list with a range of node names."""
node_list = ["node[1-3]", "node5"]
expected = ["node1", "node2", "node3", "node5"]
assert SlurmSystem.parse_node_list(node_list) == expected
def test_parse_squeue_output(slurm_system):
squeue_output = "nodeA001|root\nnodeA002|user"
expected_map = {"nodeA001": "root", "nodeA002": "user"}
result = slurm_system.parse_squeue_output(squeue_output)
assert result == expected_map


def test_parse_node_list_zero_padding() -> None:
"""Test parsing a list with zero-padded node ranges."""
node_list = ["node[001-003]", "node005"]
expected = ["node001", "node002", "node003", "node005"]
assert SlurmSystem.parse_node_list(node_list) == expected
def test_parse_squeue_output_with_node_ranges_and_root_user(slurm_system):
squeue_output = "nodeA[001-008],nodeB[001-008]|root"
user_map = slurm_system.parse_squeue_output(squeue_output)

expected_nodes = [f"nodeA{str(i).zfill(3)}" for i in range(1, 9)] + [f"nodeB{str(i).zfill(3)}" for i in range(1, 9)]
expected_map = {node: "root" for node in expected_nodes}

assert user_map == expected_map, "All nodes should be mapped to 'root'"


def test_parse_sinfo_output(slurm_system):
sinfo_output = (
"PARTITION AVAIL TIMELIMIT NODES STATE NODELIST\n"
"main up infinite 1 idle nodeA001\n"
"main up infinite 1 idle nodeB001"
)
node_user_map = {"nodeA001": "root", "nodeB001": "user"}
slurm_system.parse_sinfo_output(sinfo_output, node_user_map)
assert slurm_system.partitions["main"][0].state == SlurmNodeState.IDLE
assert slurm_system.partitions["main"][1].state == SlurmNodeState.IDLE


def test_update_node_states_with_mocked_outputs(slurm_system):
squeue_output = "nodeA001|root"
sinfo_output = "PARTITION AVAIL TIMELIMIT NODES STATE NODELIST\nmain up infinite 1 idle nodeA001"
slurm_system.update_node_states(squeue_output=squeue_output, sinfo_output=sinfo_output)
assert slurm_system.partitions["main"][0].state == SlurmNodeState.IDLE
assert slurm_system.partitions["main"][0].user == "root"

0 comments on commit 0779737

Please sign in to comment.