Skip to content

Commit

Permalink
simplify function for lintrunner
Browse files Browse the repository at this point in the history
Signed-off-by: Rob Elliott <robert.elliott@arm.com>
  • Loading branch information
robell committed Sep 19, 2024
1 parent 71607fa commit 6af34b2
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions examples/arm/aot_arm_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,14 @@ def forward(self, x):
]


def get_compile_spec(target: str) -> ArmCompileSpecBuilder:
def get_compile_spec(target: str, intermediates: bool) -> ArmCompileSpecBuilder:
spec_builder = None
if target == "TOSA":
return (
spec_builder = (
ArmCompileSpecBuilder().tosa_compile_spec().set_permute_memory_format(True)
)
elif target == "ethos-u55-128":
return (
spec_builder = (
ArmCompileSpecBuilder()
.ethosu_compile_spec(
"ethos-u55-128",
Expand All @@ -172,7 +173,7 @@ def get_compile_spec(target: str) -> ArmCompileSpecBuilder:
.set_quantize_io(True)
)
elif target == "ethos-u85-128":
return (
spec_builder = (
ArmCompileSpecBuilder()
.ethosu_compile_spec(
"ethos-u85-128",
Expand All @@ -183,8 +184,13 @@ def get_compile_spec(target: str) -> ArmCompileSpecBuilder:
.set_permute_memory_format(True)
)

if intermediates is not None:
spec_builder.dump_intermediate_artifacts_to(args.intermediates)

if __name__ == "__main__":
return spec_builder.build()


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
Expand Down Expand Up @@ -241,8 +247,12 @@ def get_compile_spec(target: str) -> ArmCompileSpecBuilder:
required=False,
help="Location for outputs, if not the default of cwd.",
)

args = parser.parse_args()
return args


if __name__ == "__main__":
args = get_args()

if args.debug:
logging.basicConfig(level=logging.DEBUG, format=FORMAT, force=True)
Expand Down Expand Up @@ -286,12 +296,11 @@ def get_compile_spec(target: str) -> ArmCompileSpecBuilder:

# As we can target multiple output encodings from ArmBackend, one must
# be specified.
compile_spec = None
if args.delegate is True:
compile_spec = get_compile_spec(args.target)
if args.intermediates is not None:
compile_spec.dump_intermediate_artifacts_to(args.intermediates)
compile_spec = compile_spec.build()
compile_spec = (
get_compile_spec(args.target, args.intermediates)
if args.delegate is True
else None
)

logging.debug(f"Exported graph:\n{edge.exported_program().graph}")
if args.delegate is True:
Expand Down

0 comments on commit 6af34b2

Please sign in to comment.