diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index b0597ed1cf..04bdb3f3cf 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -153,6 +153,37 @@ def forward(self, x): "TOSA", ] + +def get_compile_spec(target: str) -> ArmCompileSpecBuilder: + if target == "TOSA": + return ( + ArmCompileSpecBuilder().tosa_compile_spec().set_permute_memory_format(True) + ) + elif target == "ethos-u55-128": + return ( + ArmCompileSpecBuilder() + .ethosu_compile_spec( + "ethos-u55-128", + system_config="Ethos_U55_High_End_Embedded", + memory_mode="Shared_Sram", + extra_flags="--debug-force-regor --output-format=raw", + ) + .set_permute_memory_format(args.model_name in MODEL_NAME_TO_MODEL.keys()) + .set_quantize_io(True) + ) + elif target == "ethos-u85-128": + return ( + ArmCompileSpecBuilder() + .ethosu_compile_spec( + "ethos-u85-128", + system_config="Ethos_U85_SYS_DRAM_Mid", + memory_mode="Shared_Sram", + extra_flags="--output-format=raw", + ) + .set_permute_memory_format(True) + ) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( @@ -175,6 +206,7 @@ def forward(self, x): action="store", required=False, default="ethos-u55-128", + choices=targets, help=f"For ArmBackend delegated models, pick the target, and therefore the instruction set generated. valid targets are {targets}", ) parser.add_argument( @@ -256,39 +288,7 @@ def forward(self, x): # be specified. compile_spec = None if args.delegate is True: - if args.target == "TOSA": - compile_spec = ( - ArmCompileSpecBuilder() - .tosa_compile_spec() - .set_permute_memory_format(True) - ) - elif args.target == "ethos-u55-128": - compile_spec = ( - ArmCompileSpecBuilder() - .ethosu_compile_spec( - "ethos-u55-128", - system_config="Ethos_U55_High_End_Embedded", - memory_mode="Shared_Sram", - extra_flags="--debug-force-regor --output-format=raw", - ) - .set_permute_memory_format( - args.model_name in MODEL_NAME_TO_MODEL.keys() - ) - .set_quantize_io(True) - ) - elif args.target == "ethos-u85-128": - compile_spec = ( - ArmCompileSpecBuilder() - .ethosu_compile_spec( - "ethos-u85-128", - system_config="Ethos_U85_SYS_DRAM_Mid", - memory_mode="Shared_Sram", - extra_flags="--output-format=raw", - ) - .set_permute_memory_format(True) - ) - else: - raise RuntimeError(f"Expected a target in {targets}, found {args.target}") + compile_spec = get_compile_spec(args.target) if args.intermediates is not None: compile_spec.dump_intermediate_artifacts_to(args.intermediates) compile_spec = compile_spec.build()