Skip to content

Commit

Permalink
Merge branch 'main' into bump-kueue-version
Browse files Browse the repository at this point in the history
  • Loading branch information
PBundyra authored Oct 4, 2024
2 parents 8401d45 + fba388f commit 2aa0707
Showing 1 changed file with 18 additions and 52 deletions.
70 changes: 18 additions & 52 deletions src/xpk/core/pathways.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
from .system_characteristics import SystemCharacteristics

PathwaysExpectedInstancesMap = {
'v6e': 'v6e',
'v5p': 'v5',
'v5litepod': 'v5e',
'v4': 'v4',
'v3': 'v3',
'v6e': 'tpuv6e',
'v5p': 'tpuv5',
'v5litepod': 'tpuv5e',
'v4': 'tpuv4',
'v3': 'tpuv3',
}


Expand All @@ -41,18 +41,9 @@ def get_pathways_worker_args(args) -> str:
Returns:
str: yaml containing arguments for the Pathways workers.
"""
yaml = """- --alsologtostderr
- --pathways_server_port=38677
- --pathways_resource_manager={rm_address}
- --pathways_persistent_compilation_cache=false
- --xla_tpu_enable_data_parallel_all_reduce_opt=true
- --xla_tpu_data_parallel_opt_different_sized_ops=true
- --xla_tpu_enable_async_collective_fusion=true
- --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true
- --xla_tpu_enable_async_collective_fusion_multiple_steps=true
- --xla_tpu_overlap_compute_collective_tc=true
- --xla_enable_async_all_gather=true
- --pathways_tmp_dir_pattern={args.pathways_gcs_location}"""
yaml = """- --server_port=38677
- --resource_manager_address={rm_address}
- --gcs_scratch_location={args.pathways_gcs_location}"""
if args.use_pathways:
return yaml.format(args=args, rm_address=get_rm_address(args))
else:
Expand All @@ -67,12 +58,9 @@ def get_pathways_proxy_args(args) -> str:
Returns:
str: yaml containing arguments for the Pathways proxy.
"""
yaml = """- --alsologtostderr
- --v=0
- --pathways_ifrt_proxy_server_resource_manager={rm_address}
- --pathways_ifrt_proxy_server_port=38676
- --pathways_tmp_dir_pattern={args.pathways_gcs_location}
- --pathways_plaque_network=gcp"""
yaml = """- --server_port=38676
- --resource_manager_address={rm_address}
- --gcs_scratch_location={args.pathways_gcs_location}"""

if args.use_pathways:
return yaml.format(args=args, rm_address=get_rm_address(args))
Expand Down Expand Up @@ -210,17 +198,16 @@ def get_pathways_rm_args(args, system: SystemCharacteristics) -> str:
Returns:
str: yaml containing arguments for the Pathways resource manager.
"""
yaml = """- --alsologtostderr
- --pathways_server_port=38677
- --pathways_server_provides_devices=false
- --pathways_device_type=NONE
- --pathways_persistent_compilation_cache=false
- --pathways_tmp_dir_pattern={args.pathways_gcs_location}
- --pathways_expected_instances={expected_instances}"""
yaml = """- --server_port=38677
- --gcs_scratch_location={args.pathways_gcs_location}
- --node_type=resource_manager
- --instance_count={instance_count}
- --instance_type={instance_type}"""
if args.use_pathways:
return yaml.format(
args=args,
expected_instances=compute_pathways_expected_instances(args, system),
instance_count=args.num_slices,
instance_type=f'{get_pathways_expected_tpu_type(system.device_type)}:{system.topology}',
)
else:
return ''
Expand Down Expand Up @@ -301,27 +288,6 @@ def get_proxy_address(args) -> str:
return proxy_address


def compute_pathways_expected_instances(
args, system: SystemCharacteristics
) -> str:
"""Computes the expected instances from the system characteristics.
Args:
args: user provided args.
system: system characteristics.
Returns:
str: formatted string representing the expected instances (eg:
"tpuv4:2x2x2,tpuv4:2x2x2" for 2 slices of v4-16).
"""
expected_instances = ','.join([
f'tpu{get_pathways_expected_tpu_type(system.device_type)}:{system.topology}'
for _ in range(args.num_slices)
])

xpk_print(f'Pathways expected instances are: {expected_instances}')
return expected_instances


def get_pathways_expected_tpu_type(device_type: str) -> str:
"""Returns the device type expected by Pathways
Args:
Expand Down

0 comments on commit 2aa0707

Please sign in to comment.