diff --git a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/resume.py b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/resume.py index e20a7ed195..ad7e902c3d 100755 --- a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/resume.py +++ b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/resume.py @@ -159,16 +159,12 @@ def create_instances_request(nodes, partition_name, placement_group, job_id=None # key is instance name, value overwrites properties body.perInstanceProperties = {k: per_instance_properties(k) for k in nodes} + zone_allow = nodeset.zone_policy_allow or [] + zone_deny = nodeset.zone_policy_deny or [] body.locationPolicy.locations = { - **{ - f"zones/{zone}": {"preference": "ALLOW"} - for zone in nodeset.zone_policy_allow or [] - }, - **{ - f"zones/{zone}": {"preference": "DENY"} - for zone in nodeset.zone_policy_deny or [] - }, - } + **{ f"zones/{z}": {"preference": "ALLOW"} for z in zone_allow }, + **{ f"zones/{z}": {"preference": "DENY"} for z in zone_deny }} + body.locationPolicy.targetShape = nodeset.zone_target_shape if lookup().cfg.enable_slurm_gcp_plugins: @@ -179,14 +175,15 @@ def create_instances_request(nodes, partition_name, placement_group, job_id=None request_body=body, ) - request = lookup().compute.regionInstances().bulkInsert( - project=lookup().project, region=region, body=body.to_dict() - ) + api_args = dict( + project=lookup().project, region=region, body=body.to_dict()) - if log.isEnabledFor(logging.DEBUG): - log.debug( - f"new request: endpoint={request.methodId} nodes={to_hostlist_fast(nodes)}" - ) + if len(zone_allow) == 1: # if only one zone is used, use zonal BulkInsert API, as less prone to errors + request = lookup().compute.instances().bulkInsert(**api_args, zone=zone_allow[0]) + else: + request = lookup().compute.regionInstances().bulkInsert(**api_args) + + log.debug(f"new request: endpoint={request.methodId} nodes={to_hostlist_fast(nodes)}") log_api_request(request) return request