From fba388fcf1817eca27e4d84301dd0a06a20a6311 Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Thu, 3 Oct 2024 13:47:23 -0700 Subject: [PATCH] Update Pathways-on-Cloud flags (#197) The Pathways-on-Cloud flags are changing. The latest version of the images support both the new and old versions of the flags. This PR will enable XPK to continue supporting the latest version of Pathways-on-Cloud images. After this PR, XPK will not support older versions of the Pathways-on-Cloud images. --- src/xpk/core/pathways.py | 70 +++++++++++----------------------------- 1 file changed, 18 insertions(+), 52 deletions(-) diff --git a/src/xpk/core/pathways.py b/src/xpk/core/pathways.py index 6501b13d..e6079319 100644 --- a/src/xpk/core/pathways.py +++ b/src/xpk/core/pathways.py @@ -25,11 +25,11 @@ from .system_characteristics import SystemCharacteristics PathwaysExpectedInstancesMap = { - 'v6e': 'v6e', - 'v5p': 'v5', - 'v5litepod': 'v5e', - 'v4': 'v4', - 'v3': 'v3', + 'v6e': 'tpuv6e', + 'v5p': 'tpuv5', + 'v5litepod': 'tpuv5e', + 'v4': 'tpuv4', + 'v3': 'tpuv3', } @@ -41,18 +41,9 @@ def get_pathways_worker_args(args) -> str: Returns: str: yaml containing arguments for the Pathways workers. """ - yaml = """- --alsologtostderr - - --pathways_server_port=38677 - - --pathways_resource_manager={rm_address} - - --pathways_persistent_compilation_cache=false - - --xla_tpu_enable_data_parallel_all_reduce_opt=true - - --xla_tpu_data_parallel_opt_different_sized_ops=true - - --xla_tpu_enable_async_collective_fusion=true - - --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true - - --xla_tpu_enable_async_collective_fusion_multiple_steps=true - - --xla_tpu_overlap_compute_collective_tc=true - - --xla_enable_async_all_gather=true - - --pathways_tmp_dir_pattern={args.pathways_gcs_location}""" + yaml = """- --server_port=38677 + - --resource_manager_address={rm_address} + - --gcs_scratch_location={args.pathways_gcs_location}""" if args.use_pathways: return yaml.format(args=args, rm_address=get_rm_address(args)) else: @@ -67,12 +58,9 @@ def get_pathways_proxy_args(args) -> str: Returns: str: yaml containing arguments for the Pathways proxy. """ - yaml = """- --alsologtostderr - - --v=0 - - --pathways_ifrt_proxy_server_resource_manager={rm_address} - - --pathways_ifrt_proxy_server_port=38676 - - --pathways_tmp_dir_pattern={args.pathways_gcs_location} - - --pathways_plaque_network=gcp""" + yaml = """- --server_port=38676 + - --resource_manager_address={rm_address} + - --gcs_scratch_location={args.pathways_gcs_location}""" if args.use_pathways: return yaml.format(args=args, rm_address=get_rm_address(args)) @@ -210,17 +198,16 @@ def get_pathways_rm_args(args, system: SystemCharacteristics) -> str: Returns: str: yaml containing arguments for the Pathways resource manager. """ - yaml = """- --alsologtostderr - - --pathways_server_port=38677 - - --pathways_server_provides_devices=false - - --pathways_device_type=NONE - - --pathways_persistent_compilation_cache=false - - --pathways_tmp_dir_pattern={args.pathways_gcs_location} - - --pathways_expected_instances={expected_instances}""" + yaml = """- --server_port=38677 + - --gcs_scratch_location={args.pathways_gcs_location} + - --node_type=resource_manager + - --instance_count={instance_count} + - --instance_type={instance_type}""" if args.use_pathways: return yaml.format( args=args, - expected_instances=compute_pathways_expected_instances(args, system), + instance_count=args.num_slices, + instance_type=f'{get_pathways_expected_tpu_type(system.device_type)}:{system.topology}', ) else: return '' @@ -301,27 +288,6 @@ def get_proxy_address(args) -> str: return proxy_address -def compute_pathways_expected_instances( - args, system: SystemCharacteristics -) -> str: - """Computes the expected instances from the system characteristics. - Args: - args: user provided args. - system: system characteristics. - - Returns: - str: formatted string representing the expected instances (eg: - "tpuv4:2x2x2,tpuv4:2x2x2" for 2 slices of v4-16). - """ - expected_instances = ','.join([ - f'tpu{get_pathways_expected_tpu_type(system.device_type)}:{system.topology}' - for _ in range(args.num_slices) - ]) - - xpk_print(f'Pathways expected instances are: {expected_instances}') - return expected_instances - - def get_pathways_expected_tpu_type(device_type: str) -> str: """Returns the device type expected by Pathways Args: