diff --git a/.ci/scripts/build_llama_android.sh b/.ci/scripts/build_llama_android.sh index 0afe51f0b0..a08fd5499f 100644 --- a/.ci/scripts/build_llama_android.sh +++ b/.ci/scripts/build_llama_android.sh @@ -19,7 +19,6 @@ install_executorch_and_backend_lib() { cmake -DBUCK2="${BUCK2}" \ -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \ -DANDROID_ABI="${ANDROID_ABI}" \ - -DANDROID_PLATFORM=android-23 \ -DCMAKE_INSTALL_PREFIX=cmake-android-out \ -DCMAKE_BUILD_TYPE=Release \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ @@ -41,7 +40,6 @@ build_llama_runner() { cmake -DBUCK2="${BUCK2}" \ -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK"/build/cmake/android.toolchain.cmake \ -DANDROID_ABI="${ANDROID_ABI}" \ - -DANDROID_PLATFORM=android-23 \ -DCMAKE_INSTALL_PREFIX=cmake-android-out \ -DCMAKE_BUILD_TYPE=Release -DPYTHON_EXECUTABLE=python \ -DEXECUTORCH_BUILD_XNNPACK=ON \ diff --git a/.ci/scripts/test_llava.sh b/.ci/scripts/test_llava.sh index 8ac87b2302..1057fa8f4a 100644 --- a/.ci/scripts/test_llava.sh +++ b/.ci/scripts/test_llava.sh @@ -56,7 +56,6 @@ cmake_install_executorch_libraries_for_android() { cmake \ -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ -DANDROID_ABI=arm64-v8a \ - -DANDROID_PLATFORM=android-23 \ ${EXECUTORCH_COMMON_CMAKE_ARGS} \ -B${BUILD_DIR} . @@ -93,7 +92,6 @@ cmake_build_llava_runner_for_android() { cmake \ -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ -DANDROID_ABI=arm64-v8a \ - -DANDROID_PLATFORM=android-23 \ ${LLAVA_COMMON_CMAKE_ARGS} \ -DCMAKE_PREFIX_PATH="$python_lib" \ -DLLAVA_RUNNER_NO_TORCH_DUMMY_IMAGE=ON \ diff --git a/.github/scripts/extract_benchmark_results.py b/.github/scripts/extract_benchmark_results.py index 2215a3af3c..113ff2a420 100755 --- a/.github/scripts/extract_benchmark_results.py +++ b/.github/scripts/extract_benchmark_results.py @@ -9,7 +9,6 @@ import logging import os import re -import time import zipfile from argparse import Action, ArgumentParser, Namespace from io import BytesIO @@ -26,12 +25,15 @@ # iOS-related regexes and variables IOS_TEST_SPEC_REGEX = re.compile( - r"Test Case\s+'-\[(?P\w+)\s+(?P\w+)\]'\s+measured\s+\[(?P.+)\]\s+average:\s+(?P[\d\.]+)," + r"Test Case\s+'-\[(?P\w+)\s+(?P[\w\+]+)\]'\s+measured\s+\[(?P.+)\]\s+average:\s+(?P[\d\.]+)," ) IOS_TEST_NAME_REGEX = re.compile( - r"test_(?Pforward|load|generate)_(?P\w+)_pte.*iOS_(?P\w+)_iPhone(?P\w+)" + r"test_(?Pforward|load|generate)_(?P[\w\+]+)_pte.*iOS_(?P\w+)_iPhone(?P\w+)" +) +# The backend name could contain +, i.e. tinyllama_xnnpack+custom+qe_fp32 +IOS_MODEL_NAME_REGEX = re.compile( + r"(?P[^_]+)_(?P[\w\+]+)_(?P\w+)" ) -IOS_MODEL_NAME_REGEX = re.compile(r"(?P[^_]+)_(?P\w+)_(?P\w+)") class ValidateArtifacts(Action): @@ -159,19 +161,8 @@ def initialize_ios_metadata(test_name: str) -> Dict[str, any]: ios_ver = m.group("ios_ver").replace("_", ".") iphone_ver = m.group("iphone_ver").replace("_", ".") - # NB: This looks brittle, but unless we can return iOS benchmark results in JSON - # format by the test, the mapping is needed to match with Android test - if method == "load": - metric = "model_load_time(ms)" - elif method == "forward": - metric = ( - "generate_time(ms)" - if "llama" in model_name - else "avg_inference_latency(ms)" - ) - elif method == "generate": - metric = "token_per_sec" - + # The default backend and quantization dtype if the script couldn't extract + # them from the model name backend = "" quantization = "unknown" @@ -194,8 +185,9 @@ def initialize_ios_metadata(test_name: str) -> Dict[str, any]: "availMem": 0, "totalMem": 0, }, - "metric": metric, + "method": method, # These fields will be populated later by extract_ios_metric + "metric": "", "actualValue": 0, "targetValue": 0, } @@ -210,10 +202,38 @@ def extract_ios_metric( """ Map the metric name from iOS xcresult to the benchmark result """ - if metric_name == "Clock Monotonic Time, s": - # The benchmark value is in ms - benchmark_result["actualValue"] = metric_value * 1000 - elif metric_name == "Tokens Per Second, t/s": + method = benchmark_result.get("method", "") + if not method: + return benchmark_result + + # NB: This looks brittle, but unless we can return iOS benchmark results in JSON + # format by the test, the mapping is needed to match with Android test + if method == "load": + if metric_name == "Clock Monotonic Time, s": + benchmark_result["metric"] = "model_load_time(ms)" + benchmark_result["actualValue"] = metric_value * 1000 + + elif metric_name == "Memory Peak Physical, kB": + # NB: Showing the value in mB is friendlier IMO + benchmark_result["metric"] = "peak_load_mem_usage(mb)" + benchmark_result["actualValue"] = metric_value / 1024 + + elif method == "forward": + if metric_name == "Clock Monotonic Time, s": + benchmark_result["metric"] = ( + "generate_time(ms)" + if "llama" in test_name + else "avg_inference_latency(ms)" + ) + benchmark_result["actualValue"] = metric_value * 1000 + + elif metric_name == "Memory Peak Physical, kB": + # NB: Showing the value in mB is friendlier IMO + benchmark_result["metric"] = "peak_inference_mem_usage(mb)" + benchmark_result["actualValue"] = metric_value / 1024 + + elif method == "generate" and metric_name == "Tokens Per Second, t/s": + benchmark_result["metric"] = "token_per_sec" benchmark_result["actualValue"] = metric_value return benchmark_result @@ -235,6 +255,7 @@ def extract_ios_benchmark_results( with request.urlopen(artifact_s3_url) as data: current_test_name = "" + current_metric_name = "" current_record = {} for line in data.read().decode("utf8").splitlines(): @@ -242,24 +263,25 @@ def extract_ios_benchmark_results( if not s: continue - test_class = s.group("test_class") test_name = s.group("test_name") metric_name = s.group("metric") metric_value = float(s.group("value")) - if test_name != current_test_name: - if current_record: + if test_name != current_test_name or metric_name != current_metric_name: + if current_record and current_record.get("metric", ""): # Save the benchmark result in the same format used by Android benchmark_results.append(current_record.copy()) current_test_name = test_name + current_metric_name = metric_name current_record = initialize_ios_metadata(current_test_name) current_record = extract_ios_metric( current_record, test_name, metric_name, metric_value ) - benchmark_results.append(current_record.copy()) + if current_record and current_record.get("metric", ""): + benchmark_results.append(current_record.copy()) return benchmark_results diff --git a/.github/scripts/propose_ghstack_orig_pr.py b/.github/scripts/propose_ghstack_orig_pr.py new file mode 100644 index 0000000000..a5c715e945 --- /dev/null +++ b/.github/scripts/propose_ghstack_orig_pr.py @@ -0,0 +1,135 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os +import re + +from typing import List + +# Provided by the PyGithub pip package. +from github import Auth, Github +from github.Repository import Repository + + +def parse_args(): + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument( + "--repo", + type=str, + help='The github repo to modify: e.g. "pytorch/executorch".', + required=True, + ) + parser.add_argument( + "--pr", + type=int, + help="Number of the PR in the stack to check and create corresponding PR", + required=True, + ) + return parser.parse_args() + + +def extract_stack_from_body(pr_body: str) -> List[int]: + """Extracts a list of PR numbers from a ghexport-generated PR body. + + The base of the stack is in index 0. + """ + + # Expected format. The `__->__` could appear on any line. Stop parsing + # after the blank line. This would return [1, 2, 3]. + """ + Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): + * #3 + * __->__ #2 + * #1 + + + """ + + prs = [] + ghstack_begin = ( + "Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):" + ) + ghstack_begin_seen = False + for line in pr_body.splitlines(): + if ghstack_begin in line: + ghstack_begin_seen = True + if not ghstack_begin_seen: + continue + match = re.match(r"\*(?:.*?)? #(\d+)", line) + if match: + # It's a bullet followed by an integer. + prs.append(int(match.group(1))) + return list(reversed(prs)) + + +def get_pr_stack_from_number(pr_number: int, repo: Repository) -> List[int]: + pr_stack = extract_stack_from_body(repo.get_pull(pr_number).body) + + if not pr_stack: + raise Exception( + f"Could not find PR stack in body of #{pr_number}. " + + "Please make sure that the PR was created with ghstack." + ) + + return pr_stack + + +def create_prs_for_orig_branch(pr_stack: List[int], repo: Repository): + # For the first PR, we want to merge to `main` branch, and we will update + # as we go through the stack + orig_branch_merge_base = "main" + for i in range(len(pr_stack)): + pr = repo.get_pull(pr_stack[i]) + if not pr.is_merged(): + print("The PR (and stack above) is not merged yet, skipping") + return + # Check for invariant: For the current PR, it must be gh/user/x/base <- gh/user/x/head + assert pr.base.ref.replace("base", "head") == pr.head.ref + # The PR we want to create is then "branch_to_merge" <- gh/user/x/orig + # gh/user/x/orig is the clean diff between gh/user/x/base <- gh/user/x/head + orig_branch_merge_head = pr.base.ref.replace("base", "orig") + bot_metadata = f"""This PR was created by the merge bot to help merge the original PR into the main branch. +ghstack PR number: https://github.com/pytorch/executorch/pull/{pr.number} +^ Please use this as the source of truth for the PR details, comments, and reviews +ghstack PR base: https://github.com/pytorch/executorch/tree/{pr.base.ref} +ghstack PR head: https://github.com/pytorch/executorch/tree/{pr.head.ref} +Merge bot PR base: https://github.com/pytorch/executorch/tree/{orig_branch_merge_base} +Merge bot PR head: https://github.com/pytorch/executorch/tree/{orig_branch_merge_head}""" + + existing_orig_pr = repo.get_pulls( + head="pytorch:" + orig_branch_merge_head, + base=orig_branch_merge_base, + state="open", + ) + if existing_orig_pr.totalCount > 0: + print( + f"PR for {orig_branch_merge_head} already exists {existing_orig_pr[0]}" + ) + # We don't need to create/edit because the head PR is merged and orig is finalized. + else: + repo.create_pull( + base=orig_branch_merge_base, + head=orig_branch_merge_head, + title=pr.title, + body=bot_metadata, + ) + # Advance the base for the next PR + orig_branch_merge_base = orig_branch_merge_head + + +def main(): + args = parse_args() + + with Github(auth=Auth.Token(os.environ["GITHUB_TOKEN"])) as gh: + repo = gh.get_repo(args.repo) + create_prs_for_orig_branch(get_pr_stack_from_number(args.pr, repo), repo) + + +if __name__ == "__main__": + main() diff --git a/.github/workflows/apple-perf.yml b/.github/workflows/apple-perf.yml index 7de308b1a6..f693df4779 100644 --- a/.github/workflows/apple-perf.yml +++ b/.github/workflows/apple-perf.yml @@ -76,7 +76,7 @@ jobs: # on-demand and periodic benchmarking. CRON_DEFAULT_MODELS: "stories110M,mv3,mv2,ic4,ic3,resnet50,edsr,mobilebert,w2l" CRON_DEFAULT_DEVICES: "apple_iphone_15" - CRON_DEFAULT_DELEGATES: "xnnpack,coreml,mps" + CRON_DEFAULT_DELEGATES: "nnpack,coreml,mps" run: | set -ex MODELS="${{ inputs.models }}" diff --git a/.github/workflows/ghstack_land.yml b/.github/workflows/ghstack_land.yml new file mode 100644 index 0000000000..2c91a1aa40 --- /dev/null +++ b/.github/workflows/ghstack_land.yml @@ -0,0 +1,40 @@ +name: Propose to merge ghstack orig PRs to main +on: + pull_request: + types: [closed] + branches: + - 'gh/cccclai/[0-9]+/base' + - 'gh/dbort/[0-9]+/base' + - 'gh/guangy10/[0-9]+/base' + - 'gh/helunwencser/[0-9]+/base' + - 'gh/jorgep31415/[0-9]+/base' + - 'gh/kimishpatel/[0-9]+/base' + - 'gh/kirklandsign/[0-9]+/base' + - 'gh/larryliu0820/[0-9]+/base' + - 'gh/manuelcandales/[0-9]+/base' + - 'gh/mcr229/[0-9]+/base' + - 'gh/swolchok/[0-9]+/base' + - 'gh/SS-JIA/[0-9]+/base' + +jobs: + ghstack_merge_to_main: + name: Try to create a PR with ghstack /orig branch + runs-on: ubuntu-22.04 + environment: cherry-pick-bot + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: '0' + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Try to merge PR to main + run: | + pip install pygithub + + PR_NUMBER=$(echo "$GITHUB_REF" | grep -oE '[0-9]+') + + python .github/scripts/propose_ghstack_orig_pr.py --pr $PR_NUMBER --repo pytorch/executorch + env: + GITHUB_TOKEN: ${{ secrets.GH_PYTORCHBOT_CHERRY_PICK_TOKEN }} + GITHUB_REF: ${{ github.ref }} diff --git a/CMakeLists.txt b/CMakeLists.txt index 030877eb9f..156fb24e6b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -372,6 +372,9 @@ endif() # Detect if an Android toolchain is set. if(CMAKE_TOOLCHAIN_FILE MATCHES ".*android\.toolchain\.cmake$") set(CMAKE_TOOLCHAIN_ANDROID ON) +if(NOT ANDROID_PLATFORM) + set(ANDROID_PLATFORM android-30) +endif() else() set(CMAKE_TOOLCHAIN_ANDROID OFF) endif() diff --git a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py index 222c0a7cb3..b4365bf75e 100644 --- a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py +++ b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py @@ -9,6 +9,7 @@ from typing import cast import torch +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.arm.tosa_quant_utils import dq_op from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d from executorch.exir.pass_base import ExportPass, PassResult @@ -52,12 +53,7 @@ def call(self, graph_module: torch.fx.GraphModule): NHWC_Order = (0, 2, 3, 1) HWCM_Order = (2, 3, 0, 1) for node in graph_module.graph.nodes: - if isinstance( - node.meta["val"], (tuple, torch.fx.immutable_collections.immutable_list) - ): - node_data = node.meta["val"][0].data - else: - node_data = node.meta["val"].data + node_data = get_first_fake_tensor(node).data if len(node_data.shape) == 4: dim_order = NHWC_Order diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index c4e806a842..3e061dbfeb 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -22,6 +22,7 @@ from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import ( InsertSqueezeAfterSumPass, ) +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.backends.arm._passes.meandim_to_averagepool_pass import ( ConvertMeanDimToAveragePool, ) @@ -30,6 +31,9 @@ ScalarsToAttributePass, ) from executorch.backends.arm._passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass +from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import ( + UnsqueezeScalarPlaceholdersPass, +) from executorch.exir import ExportedProgram from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.pass_manager import PassManager @@ -45,10 +49,12 @@ def transform_to_backend_pipeline( ): """Apply passes before transforming program to backend""" self.add_pass(CastInt64ToInt32Pass(exported_program)) + self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program)) self.add_pass(SizeAdjustConv2DPass()) self.add_pass(RemoveClonePass()) self.add_pass(ConvertExpandCopyToRepeatPass()) self.add_pass(ConvertMeanDimToAveragePool()) + self.add_pass(MatchArgRanksPass(exported_program)) self.add_pass(DecomposeDivPass()) self.add_pass(InsertSqueezeAfterSumPass()) self.add_pass(ConvertSplitToSlicePass()) @@ -61,6 +67,6 @@ def transform_to_backend_pipeline( return self._transform(exported_program.graph_module) def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule): - self.add_pass(DecomposeDivPass()) self.add_pass(ScalarsToAttributePass()) + self.add_pass(DecomposeDivPass()) return self._transform(graph_module) diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index 34704d2ced..0e74701ab6 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -7,9 +7,11 @@ from typing import Optional import torch +import torch.fx from executorch.exir.dialects._ops import ops as exir_ops from torch._ops import OpOverload +from torch._subclasses.fake_tensor import FakeTensor def create_node( @@ -64,3 +66,21 @@ def insert_q_dq_pair( # node's first use q.args = (anchor,) + q_params return dq + + +def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor: + """ + Returns a FakeTensor from the meta field of 'node'. + If the node contains many fake tensors, return the first one. + """ + if isinstance( + node.meta["val"], (tuple, torch.fx.immutable_collections.immutable_list) + ): + fake_tensor = node.meta["val"][0] + else: + fake_tensor = node.meta["val"] + + assert isinstance( + fake_tensor, FakeTensor + ), f'Found {fake_tensor} in meta["val"] of {node}, expected to find FakeTensor.' + return fake_tensor diff --git a/backends/arm/_passes/decompose_div_pass.py b/backends/arm/_passes/decompose_div_pass.py index 13ee8d8dff..5cdc79c1c3 100644 --- a/backends/arm/_passes/decompose_div_pass.py +++ b/backends/arm/_passes/decompose_div_pass.py @@ -8,15 +8,18 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass +edge_div_ops = (exir_ops.edge.aten.div.Tensor,) +aten_div_ops = (torch.ops.aten.div.Tensor, torch.ops.aten.div_.Tensor) + def get_div_decomposition(op) -> tuple: """ Returns the the (reciprocal_op, mul_op), where the ops depends on if the div op is in exir_ops torch.ops.aten. """ - if op == exir_ops.edge.aten.div.Tensor: + if op in edge_div_ops: return (exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.mul.Tensor) - if op == torch.ops.aten.div.Tensor: + if op in aten_div_ops: return (torch.ops.aten.reciprocal.default, torch.ops.aten.mul.Tensor) raise RuntimeError(f"Can't get div decomposition for op {op}") @@ -33,7 +36,7 @@ class DecomposeDivPass(ExportPass): """ def call_operator(self, op, args, kwargs, meta): - if op not in (exir_ops.edge.aten.div.Tensor, torch.ops.aten.div.Tensor): + if op not in (edge_div_ops + aten_div_ops): return super().call_operator(op, args, kwargs, meta) reciprocal_op, mul_op = get_div_decomposition(op) diff --git a/backends/arm/_passes/match_arg_ranks_pass.py b/backends/arm/_passes/match_arg_ranks_pass.py new file mode 100644 index 0000000000..e0cbcf294f --- /dev/null +++ b/backends/arm/_passes/match_arg_ranks_pass.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import cast + +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + get_first_fake_tensor, +) + +from executorch.exir.dialects._ops import ops as exir_ops + +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx import GraphModule, Node + + +class MatchArgRanksPass(ExportPass): + """ + For ops in 'targeted_ops', make sure that the inputs share the same rank. + New dimensions are inserted at from the beginning of the + """ + + def __init__(self, exported_program): + super().__init__() + self.exported_program = exported_program + + targeted_ops = [ + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.div.Tensor, + ] + + def _match_op_rank(self, graph_module, node, arg, max_rank): + """ + In graph_module, insert a view between arg and node to make the + rank of arg match the other args to node. + """ + shape = get_first_fake_tensor(arg).shape + rank = len(shape) + new_shape = list([1] * (max_rank - rank) + list(shape)) + with graph_module.graph.inserting_before(node): + view = create_node( + graph_module.graph, + exir_ops.edge.aten.view_copy.default, + args=(arg, new_shape), + kwargs={}, + ) + node.replace_input_with(arg, view) + + def _match_buffer_rank(self, arg, max_rank): + """ + Change arg's fake tensor meta to match max_rank if: + - arg is found in inputs_to_buffers or inputs_to_parameters. + """ + fake_tensor = get_first_fake_tensor(arg) + shape = fake_tensor.shape + rank = len(shape) + new_shape = list([1] * (max_rank - rank) + list(shape)) + + buffer_name = None + if arg.name in self.exported_program.graph_signature.inputs_to_buffers: + buffer_name = self.exported_program.graph_signature.inputs_to_buffers[ + arg.name + ] + elif arg.name in self.exported_program.graph_signature.inputs_to_parameters: + buffer_name = self.exported_program.graph_signature.inputs_to_parameters[ + arg.name + ] + if buffer_name: + new_tensor = self.exported_program.state_dict[buffer_name].reshape( + new_shape + ) + self.exported_program.state_dict[buffer_name] = new_tensor + arg.meta["val"] = fake_tensor.fake_mode.from_tensor( + new_tensor, static_shapes=True + ) + + def call(self, graph_module: GraphModule) -> PassResult: + for node in graph_module.graph.nodes: + node = cast(Node, node) + + if node.op != "call_function" or node.target not in self.targeted_ops: + continue + + # Calculate max rank of all inputs to node + max_rank = 1 + for arg in node.args: + if isinstance(arg, Node): + shape = get_first_fake_tensor(arg).shape + max_rank = max(max_rank, len(shape)) + + # Adjust output shape of args if needed. + for arg in node.args: + if not isinstance(arg, Node): + continue + shape = get_first_fake_tensor(arg).shape + rank = len(shape) + if rank == max_rank: + continue + + # If the argument is call_function, match shape by inserting view node. + if arg.op == "call_function": + self._match_op_rank(graph_module, node, arg, max_rank) + else: + # If the argument is a buffer or parameter, adjust shape by changing the fake tensor meta. + self._match_buffer_rank(arg, max_rank) + + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, True) + + def ensures(self, graph_module): + for node in graph_module.graph.nodes: + if node.op != "call_function" or node.target not in self.targeted_ops: + continue + arg0_rank = node.args[0].meta["val"].dim() + arg1_rank = node.args[1].meta["val"].dim() + if arg0_rank != arg1_rank: + raise ValueError( + "Arguments of arithmetic operators need to have the same rank!" + ) diff --git a/backends/arm/_passes/scalars_to_attribute_pass.py b/backends/arm/_passes/scalars_to_attribute_pass.py index e9e547b9c9..f1c3297165 100644 --- a/backends/arm/_passes/scalars_to_attribute_pass.py +++ b/backends/arm/_passes/scalars_to_attribute_pass.py @@ -7,7 +7,7 @@ from typing import cast, Union import torch -from executorch.backends.arm.tosa_mapping import extract_tensor_meta +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.exir.pass_base import ExportPass, PassResult from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix @@ -22,10 +22,14 @@ class ScalarsToAttributePass(ExportPass): targeted_ops = [ torch.ops.aten.add.Tensor, + torch.ops.aten.add_.Tensor, torch.ops.aten.sub.Tensor, torch.ops.aten.sub_.Tensor, + torch.ops.aten.rsub.Scalar, torch.ops.aten.mul.Tensor, + torch.ops.aten.mul_.Tensor, torch.ops.aten.div.Tensor, + torch.ops.aten.div_.Tensor, ] def call(self, graph_module: GraphModule) -> PassResult: @@ -37,7 +41,7 @@ def call(self, graph_module: GraphModule) -> PassResult: biggest_rank = 1 for arg in n.args: if isinstance(arg, Node): - _, shape, _ = extract_tensor_meta(arg.meta) + shape = get_first_fake_tensor(arg).shape biggest_rank = max(biggest_rank, len(shape)) new_args = [] diff --git a/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py b/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py new file mode 100644 index 0000000000..ad9844b526 --- /dev/null +++ b/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py @@ -0,0 +1,53 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.exir.pass_base import ExportPass, PassResult + + +class UnsqueezeScalarPlaceholdersPass(ExportPass): + """ + Placeholders that have node.meta["val"].shape = () cause issues later in the lowering. + This pass unsqueezes the placeholders to make sure shape is at least (1,). + """ + + def __init__(self, exported_program): + self.exported_program = exported_program + super().__init__() + + def call(self, graph_module: torch.fx.GraphModule): + for node in graph_module.graph.nodes: + if node.op != "placeholder": + continue + rank = node.meta["val"].dim() + if rank == 0: + if not ( + node.name in self.exported_program.graph_signature.inputs_to_buffers + or node.name + in self.exported_program.graph_signature.inputs_to_parameters + ): + continue + tensor = self.exported_program.state_dict[node.name] + if tensor.dim() == 0: + self.exported_program.state_dict[node.name] = tensor.unsqueeze(0) + node.meta["val"] = node.meta["val"].fake_mode.from_tensor( + tensor.unsqueeze(0), static_shapes=True + ) + else: + node.meta["val"] = node.meta["val"].fake_mode.from_tensor( + tensor, static_shapes=True + ) + + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, True) + + def ensures(self, graph_module: torch.fx.GraphModule): + for node in graph_module.graph.nodes: + if node.op == "placeholder": + rank = node.meta["val"].dim() + if rank == 0: + raise ValueError("Placeholders of rank 0 are not supported!") diff --git a/backends/arm/quantizer/quantization_annotation/mul_annotator.py b/backends/arm/quantizer/quantization_annotation/mul_annotator.py index 5df697f4b1..47190d380e 100644 --- a/backends/arm/quantizer/quantization_annotation/mul_annotator.py +++ b/backends/arm/quantizer/quantization_annotation/mul_annotator.py @@ -24,7 +24,7 @@ def _annotate_mul( annotated_partitions = [] for node in gm.graph.nodes: - if node.target not in (torch.ops.aten.mul.Tensor,): + if node.target not in (torch.ops.aten.mul.Tensor, torch.ops.aten.mul_.Tensor): continue mul_node = node annotated_partitions.append([mul_node]) diff --git a/backends/arm/quantizer/quantization_annotation/sub_annotator.py b/backends/arm/quantizer/quantization_annotation/sub_annotator.py index 92f1808d02..437f3e22e7 100644 --- a/backends/arm/quantizer/quantization_annotation/sub_annotator.py +++ b/backends/arm/quantizer/quantization_annotation/sub_annotator.py @@ -6,8 +6,6 @@ # pyre-unsafe -import itertools -import operator from typing import Callable, List, Optional import torch @@ -16,7 +14,6 @@ from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig from torch.ao.quantization.quantizer import QuantizationAnnotation from torch.fx import GraphModule, Node -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions @register_annotator("sub") @@ -25,14 +22,12 @@ def _annotate_sub( quantization_config: QuantizationConfig, filter_fn: Optional[Callable[[Node], bool]] = None, ) -> Optional[List[List[Node]]]: - sub_partitions = get_source_partitions( - gm.graph, [operator.sub, torch.sub, operator.isub], filter_fn - ) - sub_partitions = list(itertools.chain.from_iterable(sub_partitions.values())) annotated_partitions = [] - for sub_partition in sub_partitions: - annotated_partitions.append(sub_partition.nodes) - sub_node = sub_partition.output_nodes[0] + for node in gm.graph.nodes: + if node.target not in (torch.ops.aten.sub.Tensor, torch.ops.aten.sub_.Tensor): + continue + annotated_partitions.append(node) + sub_node = node if arm_quantizer_utils.is_annotated(sub_node): continue diff --git a/backends/arm/test/misc/test_lifted_tensor.py b/backends/arm/test/misc/test_lifted_tensor.py index 90aa7e2950..29b2887431 100644 --- a/backends/arm/test/misc/test_lifted_tensor.py +++ b/backends/arm/test/misc/test_lifted_tensor.py @@ -3,40 +3,119 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import operator import unittest +from typing import Union import torch from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester +from parameterized import parameterized class LiftedTensor(torch.nn.Module): - def __init__(self): + test_data = [ + # (operator, test_data, length) + (operator.add, (torch.randn(2, 2), 2)), + (operator.truediv, (torch.ones(2, 2), 2)), + (operator.mul, (torch.randn(2, 2), 2)), + (operator.sub, (torch.rand(2, 2), 2)), + ] + + def __init__(self, op: callable): super().__init__() + self.op = op self.lifted_tensor = torch.Tensor([[1, 2], [3, 4]]) def forward(self, x: torch.Tensor, length) -> torch.Tensor: sliced = self.lifted_tensor[:, :length] - return sliced + x + return self.op(sliced, x) + + +class LiftedScalarTensor(torch.nn.Module): + test_data = [ + # (operator, test_data) + (operator.add, (torch.randn(2, 2),), 1.0), + (operator.truediv, (torch.randn(4, 2),), 1.0), + (operator.mul, (torch.randn(1, 2),), 2.0), + (operator.sub, (torch.randn(3),), 1.0), + ] + + def __init__(self, op: callable, arg1: Union[int, float, torch.tensor]): + super().__init__() + self.op = op + self.arg1 = arg1 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.op(x, self.arg1) class TestLiftedTensor(unittest.TestCase): """Tests the ArmPartitioner with a placeholder of type lifted tensor.""" - def test_partition_lifted_tensor(self): + @parameterized.expand(LiftedTensor.test_data) + def test_partition_lifted_tensor_tosa_MI(self, op, data): tester = ( ArmTester( - LiftedTensor(), - example_inputs=(torch.ones(2, 2), 2), + LiftedTensor(op), + example_inputs=data, compile_spec=common.get_tosa_compile_spec(), ) .export() .to_edge() - .dump_artifact() ) signature = tester.get_artifact().exported_program().graph_signature assert len(signature.lifted_tensor_constants) > 0 tester.partition() tester.to_executorch() - tester.run_method_and_compare_outputs((torch.ones(2, 2), 2)) + tester.run_method_and_compare_outputs(data) + + @parameterized.expand(LiftedTensor.test_data) + def test_partition_lifted_tensor_tosa_BI(self, op, data): + tester = ( + ArmTester( + LiftedTensor(op), + example_inputs=data, + compile_spec=common.get_tosa_compile_spec(), + ) + .quantize() + .export() + .to_edge() + ) + signature = tester.get_artifact().exported_program().graph_signature + assert len(signature.lifted_tensor_constants) == 0 + tester.partition() + tester.to_executorch() + tester.run_method_and_compare_outputs(data) + + @parameterized.expand(LiftedScalarTensor.test_data) + def test_partition_lifted_scalar_tensor_tosa_MI(self, op, data, arg1): + ( + ArmTester( + LiftedScalarTensor(op, arg1), + example_inputs=(data), + compile_spec=common.get_tosa_compile_spec(), + ) + .export() + .to_edge() + .partition() + .to_executorch() + .run_method_and_compare_outputs(data) + ) + + @parameterized.expand(LiftedScalarTensor.test_data) + def test_partition_lifted_scalar_tensor_tosa_BI(self, op, data, arg1): + ( + ArmTester( + LiftedScalarTensor(op, arg1), + example_inputs=(data), + compile_spec=common.get_tosa_compile_spec(), + ) + .quantize() + .export() + .to_edge() + .partition() + .to_executorch() + .run_method_and_compare_outputs(data) + ) diff --git a/backends/arm/test/ops/test_scalars.py b/backends/arm/test/ops/test_scalars.py new file mode 100644 index 0000000000..154ca82022 --- /dev/null +++ b/backends/arm/test/ops/test_scalars.py @@ -0,0 +1,162 @@ +import unittest + +import torch + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from parameterized import parameterized + +""" +Summary of non-working cases. +MI: + Any case with int scalar: A to_copy is inserted to cast the value which we don't partition. + This makes the constant end up outside our partition and the input to the delegate becomes + a to_copy placeholder. In ArmTester, the placeholder is then interpreted as an input. + Potential fix: partition int -> float to_copy-ops in ArmBackend. + # MLETORCH-407 + Op(scalar, tensor): + One issue is that lift_constant_tensor_pass looks for a fake_tensor in the meta of the first + node which does not work the first node is a scalar. + Fixing that, the lowering fails since edge_program.graph_signatures.inputs_to_buffers is changed from + {"_lifted_tensor_constant0":"_lifted_tensor_constant0"} to {"x":"_lifted_tensor_constant0"} + somewhere in _transform in the to_edge step. This makes ArmPartitioner miss tagging the + data in tag_constant_data. + # MLETORCH-408 + +BI: + sub(Scalar, Tensor) becomes rsub, which either fails since the scalar does not become an attribute + in scalars_to_attribute_pass, or, if added to targeted_ops in that pass, fails since rsub expects a + Scalar. + Potential fix: Create pass to convert rsub.Scalar to sub.Tensor +""" + + +class TestScalars(unittest.TestCase): + """Tests various scalar cases for for""" + + class Add(torch.nn.Module): + def forward(self, x, y): + return x + y + + class Sub(torch.nn.Module): + def forward(self, x, y): + return x - y + + class Div(torch.nn.Module): + def forward(self, x, y): + return x / y + + class Mul(torch.nn.Module): + def forward(self, x, y): + return x * y + + class AddInplace(torch.nn.Module): + def forward(self, x, y): + x += y + return x + + class SubInplace(torch.nn.Module): + def forward(self, x, y): + x -= y + return x + + class DivInplace(torch.nn.Module): + def forward(self, x, y): + x /= y + return x + + class MulInplace(torch.nn.Module): + def forward(self, x, y): + x *= y + return x + + class AddConst(torch.nn.Module): + def forward(self, x): + x = 1.0 + x + return x + + # Inplace ops end with '_' (from aten naming) + ops = [ + ("Add", Add()), + ("Sub", Sub()), + ("Mul", Mul()), + ("Div", Div()), + ("Add_", AddInplace()), + ("Sub_", SubInplace()), + ("Mul_", MulInplace()), + ("Div_", DivInplace()), + ] + + const_ops = [("Add", AddConst())] + + dtypes = [("int", 3), ("float", 3.0)] + sizes = [("r1", (1)), ("r4", (2, 4, 5, 3))] + + # Create combinations of tests + tensor_scalar_tests = [] + for op in ops: + for dtype in dtypes: + for size in sizes: + test_name = f"{op[0]}_{dtype[0]}_{size[0]}" + tensor = torch.rand(size[1]) + scalar = dtype[1] + tensor_scalar_tests.append((test_name + "_ts", op[1], tensor, scalar)) + + # Don't add (scalar, tensor) test case for inplace ops. + if op[0][-1] == "_": + continue + + # sub(scalar, tensor) does not work in any case. + if op[0][0:3] == "Sub": + continue + tensor_scalar_tests.append((test_name + "_st", op[1], scalar, tensor)) + + tensor_const_tests = [] + for op in const_ops: + for size in sizes: + test_name = f"{op[0]}_{size[0]}" + tensor = torch.rand(size[1]) + tensor_const_tests.append((test_name, op[1], tensor)) + + def _test_add_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: tuple): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec(), + ) + .export() + .to_edge() + .partition() + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + def _test_add_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: tuple): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec(), + ) + .quantize() + .export() + .to_edge() + .partition() + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + # Most MI tests fail, just show one working for now. + @parameterized.expand((tensor_scalar_tests[6],)) + def test_MI(self, test_name: str, op: torch.nn.Module, x, y): + self._test_add_tosa_MI_pipeline(op, (x, y)) + + # op(Scalar float, tensor) works if the scalar is constant. + @parameterized.expand(tensor_const_tests) + def test_MI_const(self, test_name: str, op: torch.nn.Module, x): + self._test_add_tosa_MI_pipeline(op, (x,)) + + @parameterized.expand(tensor_scalar_tests) + def test_BI(self, test_name: str, op: torch.nn.Module, x, y): + self._test_add_tosa_BI_pipeline(op, (x, y)) diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 053ddc3a8e..59d326109d 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -293,9 +293,9 @@ def run_method_and_compare_outputs( test_input: list[torch.Tensor] = [] for arg in reference_input: if isinstance(arg, torch.Tensor): - test_input.append(arg) + test_input.append(arg.clone()) if isinstance(arg, tuple) and isinstance(arg[0], torch.Tensor): - test_input.extend(list(arg)) + test_input.extend([tensor.clone() for tensor in arg]) if ( is_nhwc diff --git a/backends/cadence/aot/functions_hifi.yaml b/backends/cadence/aot/functions_hifi.yaml index 729db66850..5c25d89946 100644 --- a/backends/cadence/aot/functions_hifi.yaml +++ b/backends/cadence/aot/functions_hifi.yaml @@ -107,21 +107,21 @@ variants: function kernels: - arg_meta: null - kernel_name: impl::HiFi::quantize_per_tensor_out + kernel_name: cadence::impl::HiFi::quantize_per_tensor_out - func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null - kernel_name: impl::HiFi::dequantize_per_tensor_out + kernel_name: cadence::impl::HiFi::dequantize_per_tensor_out - func: cadence::quantized_layer_norm.out(Tensor input, Tensor in_scale, Tensor in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::HiFi::quantized_layer_norm_out + kernel_name: cadence::impl::HiFi::quantized_layer_norm_out - func: cadence::quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::HiFi::quantized_linear_out + kernel_name: cadence::impl::HiFi::quantized_linear_out diff --git a/backends/cadence/hifi/kernels/kernels.cpp b/backends/cadence/hifi/kernels/kernels.cpp index 4d9183e4cc..10e5fb176e 100644 --- a/backends/cadence/hifi/kernels/kernels.cpp +++ b/backends/cadence/hifi/kernels/kernels.cpp @@ -10,6 +10,7 @@ #include #include +namespace cadence { namespace impl { namespace HiFi { namespace kernels { @@ -231,3 +232,4 @@ typed_requantize_vec(uint8_t, int8_t); }; // namespace kernels }; // namespace HiFi }; // namespace impl +}; // namespace cadence diff --git a/backends/cadence/hifi/kernels/kernels.h b/backends/cadence/hifi/kernels/kernels.h index b565982461..d27e8051f5 100644 --- a/backends/cadence/hifi/kernels/kernels.h +++ b/backends/cadence/hifi/kernels/kernels.h @@ -12,6 +12,7 @@ #include #include +namespace cadence { namespace impl { namespace HiFi { namespace kernels { @@ -63,3 +64,4 @@ void dequantize( }; // namespace kernels }; // namespace HiFi }; // namespace impl +}; // namespace cadence diff --git a/backends/cadence/hifi/operators/dequantize_per_tensor.cpp b/backends/cadence/hifi/operators/dequantize_per_tensor.cpp index 935ff8a501..2a548fb231 100644 --- a/backends/cadence/hifi/operators/dequantize_per_tensor.cpp +++ b/backends/cadence/hifi/operators/dequantize_per_tensor.cpp @@ -10,6 +10,7 @@ #include #include +namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -50,3 +51,4 @@ void dequantize_per_tensor_out( }; // namespace native }; // namespace HiFi }; // namespace impl +}; // namespace cadence diff --git a/backends/cadence/hifi/operators/quantize_per_tensor.cpp b/backends/cadence/hifi/operators/quantize_per_tensor.cpp index f17c865392..9b2034973f 100644 --- a/backends/cadence/hifi/operators/quantize_per_tensor.cpp +++ b/backends/cadence/hifi/operators/quantize_per_tensor.cpp @@ -10,6 +10,7 @@ #include #include +namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -21,28 +22,32 @@ using executorch::runtime::KernelRuntimeContext; // Quantize the input tensor (PT2 version). Note that quant_ are not // used in any computation. void quantize_per_tensor_out( - KernelRuntimeContext& context, + KernelRuntimeContext& ctx, const Tensor& input, double scale, int64_t zero_point, - int64_t quant_min, - int64_t quant_max, + __ET_UNUSED int64_t quant_min, + __ET_UNUSED int64_t quant_max, ScalarType dtype, Tensor& out) { const float* input_data = input.const_data_ptr(); - size_t numel = out.numel(); + const size_t numel = out.numel(); if (out.scalar_type() == ScalarType::Byte) { uint8_t* out_data = out.mutable_data_ptr(); - impl::HiFi::kernels::quantize( + cadence::impl::HiFi::kernels::quantize( out_data, input_data, 1. / scale, zero_point, numel); } else if (out.scalar_type() == ScalarType::Char) { int8_t* out_data = out.mutable_data_ptr(); xa_nn_elm_quantize_f32_asym8s( out_data, input_data, scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Short) { + int16_t* out_data = out.mutable_data_ptr(); + cadence::impl::HiFi::kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); } else if (out.scalar_type() == ScalarType::Int) { int32_t* out_data = out.mutable_data_ptr(); - impl::HiFi::kernels::quantize( + cadence::impl::HiFi::kernels::quantize( out_data, input_data, 1. / scale, zero_point, numel); } else { ET_CHECK_MSG(false, "Unhandled input dtype %hhd", out.scalar_type()); @@ -52,3 +57,4 @@ void quantize_per_tensor_out( }; // namespace native }; // namespace HiFi }; // namespace impl +}; // namespace cadence diff --git a/backends/cadence/hifi/operators/quantized_layer_norm.cpp b/backends/cadence/hifi/operators/quantized_layer_norm.cpp index 62298bff09..439bb594f5 100644 --- a/backends/cadence/hifi/operators/quantized_layer_norm.cpp +++ b/backends/cadence/hifi/operators/quantized_layer_norm.cpp @@ -16,6 +16,7 @@ using executorch::aten::Tensor; using executorch::runtime::getLeadingDims; using executorch::runtime::KernelRuntimeContext; +namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -76,10 +77,10 @@ void quantized_layer_norm_( for (size_t j = 0; j < last_dim; ++j) { // Since X is quantized, we dequantize it, compute fp32 result, and // quantize the result to an int8/uint8 value. - float val = impl::HiFi::kernels::dequantize( + float val = cadence::impl::HiFi::kernels::dequantize( x[j], input_scale, input_zero_point); val = (val - mean) * inv_std * weight_data[j] + bias_data[j]; - y[j] = impl::HiFi::kernels::quantize( + y[j] = cadence::impl::HiFi::kernels::quantize( val, output_inv_scale, output_zero_point); } } @@ -157,3 +158,4 @@ void quantized_layer_norm_out( }; // namespace native }; // namespace HiFi }; // namespace impl +}; // namespace cadence diff --git a/backends/cadence/hifi/operators/quantized_linear_out.cpp b/backends/cadence/hifi/operators/quantized_linear_out.cpp index 8a0fa5d420..8944a24ddb 100644 --- a/backends/cadence/hifi/operators/quantized_linear_out.cpp +++ b/backends/cadence/hifi/operators/quantized_linear_out.cpp @@ -11,6 +11,7 @@ #include #include +namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -45,7 +46,7 @@ void quantized_linear_out( uint8_t* __restrict__ out_data = out.mutable_data_ptr(); // The nnlib kernel to compute quantized linear via matmul. - int32_t ret = impl::HiFi::kernels::matmul_asym8uxasym8u_asym8u( + int32_t ret = cadence::impl::HiFi::kernels::matmul_asym8uxasym8u_asym8u( out_data, // p_out weight_data, // p_mat1, in_data, // p_mat2, @@ -69,3 +70,4 @@ void quantized_linear_out( }; // namespace native }; // namespace HiFi }; // namespace impl +}; // namespace cadence diff --git a/backends/cadence/hifi/operators/targets.bzl b/backends/cadence/hifi/operators/targets.bzl index c7b24d790f..a2556476a1 100644 --- a/backends/cadence/hifi/operators/targets.bzl +++ b/backends/cadence/hifi/operators/targets.bzl @@ -26,5 +26,6 @@ def define_common_targets(): ], visibility = [ "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", ], ) diff --git a/backends/cadence/hifi/third-party/nnlib/matmul_asym8uxasym8u_asym8u.cpp b/backends/cadence/hifi/third-party/nnlib/matmul_asym8uxasym8u_asym8u.cpp index 0c19e1ae59..fb944a6643 100644 --- a/backends/cadence/hifi/third-party/nnlib/matmul_asym8uxasym8u_asym8u.cpp +++ b/backends/cadence/hifi/third-party/nnlib/matmul_asym8uxasym8u_asym8u.cpp @@ -43,6 +43,7 @@ /*----------------------------Main function---------------------------------*/ +namespace cadence { namespace impl { namespace HiFi { namespace kernels { @@ -436,3 +437,4 @@ WORD32 matmul_asym8uxasym8u_asym8u( }; // namespace kernels }; // namespace HiFi }; // namespace impl +}; // namespace cadence diff --git a/backends/cadence/reference/operators/dequantize_per_tensor.cpp b/backends/cadence/reference/operators/dequantize_per_tensor.cpp index bbf427e069..aef730bfd1 100644 --- a/backends/cadence/reference/operators/dequantize_per_tensor.cpp +++ b/backends/cadence/reference/operators/dequantize_per_tensor.cpp @@ -42,7 +42,10 @@ void dequantize_per_tensor_out( impl::reference::kernels::dequantize( out_data, input_data, scale, zero_point, numel); } else { - ET_CHECK_MSG(false, "Unhandled input dtype %hhd", input.scalar_type()); + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast(input.scalar_type())); } } diff --git a/backends/cadence/reference/operators/quantize_per_tensor.cpp b/backends/cadence/reference/operators/quantize_per_tensor.cpp index df44171cf1..0d7ff0bc7e 100644 --- a/backends/cadence/reference/operators/quantize_per_tensor.cpp +++ b/backends/cadence/reference/operators/quantize_per_tensor.cpp @@ -44,7 +44,10 @@ void quantize_per_tensor_out( impl::reference::kernels::quantize( out_data, input_data, 1. / scale, zero_point, numel); } else { - ET_CHECK_MSG(false, "Unhandled input dtype %hhd", out.scalar_type()); + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast(out.scalar_type())); } } diff --git a/backends/cadence/reference/operators/quantized_layer_norm.cpp b/backends/cadence/reference/operators/quantized_layer_norm.cpp index 27b5bb7661..92b1edf3dd 100644 --- a/backends/cadence/reference/operators/quantized_layer_norm.cpp +++ b/backends/cadence/reference/operators/quantized_layer_norm.cpp @@ -145,7 +145,10 @@ void quantized_layer_norm_out( output_zero_point, out); } else { - ET_CHECK_MSG(false, "Unhandled input dtype %hhd", input.scalar_type()); + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast(input.scalar_type())); } } diff --git a/backends/cadence/reference/operators/quantized_relu_out.cpp b/backends/cadence/reference/operators/quantized_relu_out.cpp index 460084fcfb..19b971405c 100644 --- a/backends/cadence/reference/operators/quantized_relu_out.cpp +++ b/backends/cadence/reference/operators/quantized_relu_out.cpp @@ -68,7 +68,10 @@ void quantized_relu_out( out_shift, output); } else { - ET_CHECK_MSG(false, "Unhandled input dtype %hhd", input.scalar_type()); + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast(input.scalar_type())); } } diff --git a/backends/mediatek/scripts/mtk_build.sh b/backends/mediatek/scripts/mtk_build.sh index 5e6724a9b5..6c935b3c80 100755 --- a/backends/mediatek/scripts/mtk_build.sh +++ b/backends/mediatek/scripts/mtk_build.sh @@ -33,7 +33,6 @@ rm -rf cmake-android-out && mkdir cmake-android-out && cd cmake-android-out cmake -DBUCK2="$BUCK_PATH" \ -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \ -DANDROID_ABI=arm64-v8a \ - -DANDROID_PLATFORM=android-30 \ -DEXECUTORCH_BUILD_NEURON=ON \ -DNEURON_BUFFER_ALLOCATOR_LIB="$NEURON_BUFFER_ALLOCATOR_LIB" \ .. diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 298664e2c9..88a84f2f9a 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -7,7 +7,7 @@ import operator import warnings from collections import OrderedDict -from typing import Callable, Dict, List, Set, Tuple +from typing import Callable, Dict, FrozenSet, List, Set, Tuple import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor @@ -291,9 +291,8 @@ def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]: def _transform( - edge_program: ExportedProgram, custom_pass_config: Set[str] = None -) -> None: - custom_pass_config = custom_pass_config or {} + edge_program: ExportedProgram, custom_pass_config: FrozenSet[str] = frozenset() +) -> ExportedProgram: # currently ExirExportedProgram.transform does not accept # changes of input number which was caused by FoldQDQ # apply passes one by one here to avoid IR capture failure @@ -325,6 +324,7 @@ def _transform( edge_program.graph_module, ) edge_program._validate() + return edge_program def capture_program( diff --git a/backends/vulkan/partitioner/supported_ops.py b/backends/vulkan/partitioner/supported_ops.py index da50719ba3..83dfb3b768 100644 --- a/backends/vulkan/partitioner/supported_ops.py +++ b/backends/vulkan/partitioner/supported_ops.py @@ -84,6 +84,7 @@ def __contains__(self, op): exir_ops.edge.aten.addmm.default, exir_ops.edge.aten.linear.default, exir_ops.edge.et_vk.linear_weight_int4.default, + exir_ops.edge.aten._weight_int8pack_mm.default, # Reduction exir_ops.edge.aten._log_softmax.default, exir_ops.edge.aten._softmax.default, diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl index a72df89b63..02cae3ed98 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl @@ -44,10 +44,38 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; // This header file must be defined after the layout descriptors have been // declared because the functions in the header assume some variables have been // declared as layout descriptors. -#include "q_linear.h" #ifdef USING_BUFFER +#ifndef FLOAT_T +#define FLOAT_T float +#endif + +FLOAT_T q_8w_linear(const ivec4 out_idx, const int K) { + const FLOAT_T scale = t_scales[out_idx.x]; + + FLOAT_T outval = FLOAT_T(0.0); + + // Initial mat1 tensor idx will be (0, out_idx.y, out_idx.z, 0) + int mat1_offset = out_idx.y * mat1_strides.y + out_idx.z * qmat2_strides.z; + // Initial qmat2 tensor idx wil be (0, out_idx.x, 0, 0); note that the qmat2 + // tensor is transposed + int qmat2_offset = out_idx.x * qmat2_strides.y; + + // TODO(ssjia): optimize memory access pattern by traversing K in inner loop + for (int i = 0; i < K; i++) { + const FLOAT_T mat1_val = t_mat1[mat1_offset]; + const FLOAT_T mat2_val = t_qmat2[qmat2_offset] * scale; + + outval += mat1_val * mat2_val; + + mat1_offset++; + qmat2_offset++; + } + + return outval; +} + void main() { const int out_bufi = int(gl_GlobalInvocationID.x); if (out_bufi >= out_numel) { @@ -61,6 +89,36 @@ void main() { #else // USING_TEXTURE +VEC4_T q_8w_linear(const ivec3 out_pos, const int K) { + ivec3 mat1_pos = ivec3(0, out_pos.yz); + ivec3 qmat2_pos = ivec3(0, out_pos.x * 4, 0); + + VEC4_T outtex = VEC4_T(0); + + const ivec3 scales_pos = ivec3(out_pos.x, 0, 0); + const VEC4_T scales = load_texel(t_scales, scales_pos); + + for (int i = 0; i < K; i += 4) { + const VEC4_T mat1_tex = load_texel(t_mat1, mat1_pos); + + const VEC4_T sums = VEC4_T( + dot(mat1_tex, load_texel(t_qmat2, qmat2_pos) * scales.x), + dot(mat1_tex, + load_texel(t_qmat2, qmat2_pos + ivec3(0, 1, 0)) * scales.y), + dot(mat1_tex, + load_texel(t_qmat2, qmat2_pos + ivec3(0, 2, 0)) * scales.z), + dot(mat1_tex, + load_texel(t_qmat2, qmat2_pos + ivec3(0, 3, 0)) * scales.w)); + + outtex += sums; + + mat1_pos.x++; + qmat2_pos.x++; + } + + return outtex; +} + void main() { const ivec3 out_pos = ivec3(gl_GlobalInvocationID); if (any(greaterThanEqual(out_pos, out_limits))) { diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl new file mode 100644 index 0000000000..dae2f7e3ab --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl @@ -0,0 +1,211 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} +#define FLOAT_T ${buffer_scalar_type(DTYPE)} + +${define_active_storage_type(STORAGE)} + +${define_required_extensions(DTYPE)} +${define_required_extensions("int8")} + + +$if BATCH_MODE: + #define BATCH_MODE + +#define TILE_ROWS ${TILE_ROWS} +#define FOUR 4 + +// we avoid mat4 and vec4 usage here as they compile to much less efficient +// SPIR-V +struct FloatMatrix_2d { + float data[TILE_ROWS][FOUR]; +}; + +struct FloatMatrix_3d { + float data[TILE_ROWS][FOUR][FOUR]; +}; + +#ifdef BATCH_MODE + #define FloatMatrix FloatMatrix_3d +#else + #define FloatMatrix FloatMatrix_2d +#endif + +#include "indexing_utils.h" + +layout(std430) buffer; + +${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_tensor(1, "r", "t_mat1", DTYPE, STORAGE)} +${layout_declare_tensor(2, "r", "t_qmat2", "int8", STORAGE)} +${layout_declare_tensor(3, "r", "t_scales", DTYPE, STORAGE)} + +$if STORAGE == "buffer": + ${layout_declare_ubo(4, "ivec4", "out_sizes")} + ${layout_declare_ubo(5, "ivec4", "out_strides")} + ${layout_declare_ubo(6, "int", "out_numel")} + ${layout_declare_ubo(7, "ivec4", "mat1_sizes")} + ${layout_declare_ubo(8, "ivec4", "mat1_strides")} + ${layout_declare_ubo(9, "ivec4", "qmat2_strides")} + ${layout_declare_ubo(10, "ivec4", "scales_strides")} +$else: + ${layout_declare_ubo(4, "ivec3", "out_limits")} + ${layout_declare_ubo(5, "ivec4", "mat1_sizes")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// This header file must be defined after the layout descriptors have been +// declared because the functions in the header assume some variables have been +// declared as layout descriptors. + +#ifdef USING_BUFFER + +#ifndef FLOAT_T +#define FLOAT_T float +#endif + +FLOAT_T q_8w_linear(const ivec4 out_idx, const int K) { + const FLOAT_T scale = t_scales[out_idx.x]; + + FLOAT_T outval = FLOAT_T(0.0); + + // Initial mat1 tensor idx will be (0, out_idx.y, out_idx.z, 0) + int mat1_offset = out_idx.y * mat1_strides.y + out_idx.z * qmat2_strides.z; + // Initial qmat2 tensor idx wil be (0, out_idx.x, 0, 0); note that the qmat2 + // tensor is transposed + int qmat2_offset = out_idx.x * qmat2_strides.y; + + // TODO(ssjia): optimize memory access pattern by traversing K in inner loop + for (int i = 0; i < K; i++) { + const FLOAT_T mat1_val = t_mat1[mat1_offset]; + const FLOAT_T mat2_val = t_qmat2[qmat2_offset] * scale; + + outval += mat1_val * mat2_val; + + mat1_offset++; + qmat2_offset++; + } + + return outval; +} + +void main() { + const int out_bufi = int(gl_GlobalInvocationID.x); + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, 0); + + t_out[out_bufi] = q_8w_linear(out_tidx, mat1_sizes.x); +} + +#else // USING_TEXTURE +FloatMatrix q_8w_linear_optimized(const ivec3 out_idx_tl) { + FloatMatrix results; + for (int i = 0; i < TILE_ROWS; i++) { + for (int j = 0; j < FOUR; j++) { +#ifdef BATCH_MODE + for (int k = 0; k < FOUR; k++) { + results.data[i][j][k] = 0.0f; + } +#else + results.data[i][j] = 0.0f; +#endif // BATCH_MODE + } + } + + VEC4_T im_mat1_partial_load[TILE_ROWS]; + VEC4_T im_mat2_partial_load[FOUR]; + +#ifdef BATCH_MODE + for (int batch_idx = 0; batch_idx < FOUR; batch_idx++) { + if (out_idx_tl.z + batch_idx >= out_limits.z) { + break; + } +#endif + for (int k = 0; k < mat1_sizes.x; k++) { + for (int r = 0; r < TILE_ROWS; r++) { + ivec3 mat1_pos = ivec3(k, out_idx_tl.y * TILE_ROWS + r, 0); +#ifdef BATCH_MODE + mat1_pos[2] = out_idx_tl.z + batch_idx; +#endif + + im_mat1_partial_load[r] = texelFetch(t_mat1, mat1_pos, 0); + } + + for (int r = 0; r < FOUR; ++r) { + ivec3 qmat2_pos = ivec3(k, FOUR * out_idx_tl.x + r, 0); + + im_mat2_partial_load[r] = texelFetch(t_qmat2, qmat2_pos, 0); + } + + vec4 scales = texelFetch(t_scales, ivec3(out_idx_tl.x, 0, 0), 0); + + // perform partial dot products and add partial result to results + for (int out_row = 0; out_row < TILE_ROWS; out_row++) { + for (int out_col = 0; out_col < FOUR; out_col++) { +#ifdef BATCH_MODE + results.data[out_row][out_col][batch_idx] += +#else + results.data[out_row][out_col] += +#endif + dot(im_mat1_partial_load[out_row], + im_mat2_partial_load[out_col] * scales[out_col]); + } + } + } +#ifdef BATCH_MODE + } +#endif + return results; +} + +void main() { + const ivec3 out_idx = ivec3(gl_GlobalInvocationID); + if (any(greaterThanEqual(out_idx, out_limits))) { + return; + } + + FloatMatrix results = q_8w_linear_optimized(out_idx); + + ivec3 out_pos = ivec3( + out_idx.x, + out_idx.y * TILE_ROWS, +#ifdef BATCH_MODE + out_idx.z * 4 +#else + out_idx.z +#endif +); + + for (int idx_c = 0; idx_c < TILE_ROWS; idx_c++, out_pos[1]++) { + out_pos.x = out_idx.x; + $if BATCH_MODE: + for (int idx_r = 0; idx_r < FOUR; idx_r++, out_pos[0]++) { + write_texel(t_out, out_pos, VEC4_T( + results.data[idx_c][idx_r][0], + results.data[idx_c][idx_r][1], + results.data[idx_c][idx_r][2], + results.data[idx_c][idx_r][3])); + } + $else: + write_texel(t_out, out_pos, VEC4_T( + results.data[idx_c][0], + results.data[idx_c][1], + results.data[idx_c][2], + results.data[idx_c][3])); + } +} + +#endif diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.yaml b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.yaml new file mode 100644 index 0000000000..52bebf9012 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.yaml @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +q_8w_linear_optimized: + parameter_names_with_default_values: + DTYPE: float + STORAGE: texture3d + MAT1_PACKING: W_packed + MAT2_PACKING: W_packed + BATCH_MODE: false + TILE_ROWS: 4 + generate_variant_forall: + TILE_ROWS: + - VALUE: 4 + SUFFIX: tile_row_4 + - VALUE: 2 + SUFFIX: tile_row_2 + DTYPE: + - VALUE: float + - VALUE: half + STORAGE: + - VALUE: texture3d + - VALUE: buffer + shader_variants: + - NAME: q_8w_linear_optimized_W_packed_W_packed + - NAME: q_8w_linear_optimized_W_packed_H_packed + MAT2_PACKING: H_packed + - NAME: batch_q_8w_linear_optimized_W_packed_W_packed + BATCH_MODE: true + - NAME: batch_q_8w_linear_optimized_W_packed_H_packed + MAT2_PACKING: H_packed + BATCH_MODE: true diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_linear.h b/backends/vulkan/runtime/graph/ops/glsl/q_linear.h deleted file mode 100644 index f6de1e6dcf..0000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/q_linear.h +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#ifndef Q_LINEAR_H -#define Q_LINEAR_H - -#include "indexing_utils.h" - -// The functions in this file assume that some variables have been defined as -// descriptors, such as t_mat1, t_qmat2, t_scales, etc. - -#ifdef USING_BUFFER - -#ifndef FLOAT_T -#define FLOAT_T float -#endif - -FLOAT_T q_8w_linear(const ivec4 out_idx, const int K) { - const FLOAT_T scale = t_scales[out_idx.x]; - - FLOAT_T outval = FLOAT_T(0.0); - - // Initial mat1 tensor idx will be (0, out_idx.y, out_idx.z, 0) - int mat1_offset = out_idx.y * mat1_strides.y + out_idx.z * qmat2_strides.z; - // Initial qmat2 tensor idx wil be (0, out_idx.x, 0, 0); note that the qmat2 - // tensor is transposed - int qmat2_offset = out_idx.x * qmat2_strides.y; - - // TODO(ssjia): optimize memory access pattern by traversing K in inner loop - for (int i = 0; i < K; i++) { - const FLOAT_T mat1_val = t_mat1[mat1_offset]; - const FLOAT_T mat2_val = t_qmat2[qmat2_offset] * scale; - - outval += mat1_val * mat2_val; - - mat1_offset++; - qmat2_offset++; - } - - return outval; -} - -#else // USING_TEXTURE - -VEC4_T q_8w_linear(const ivec3 out_pos, const int K) { - ivec3 mat1_pos = ivec3(0, out_pos.yz); - ivec3 qmat2_pos = ivec3(0, out_pos.x * 4, 0); - - VEC4_T outtex = VEC4_T(0); - - const ivec3 scales_pos = ivec3(out_pos.x, 0, 0); - const VEC4_T scales = load_texel(t_scales, scales_pos); - - for (int i = 0; i < K; i += 4) { - const VEC4_T mat1_tex = load_texel(t_mat1, mat1_pos); - - const VEC4_T sums = VEC4_T( - dot(mat1_tex, load_texel(t_qmat2, qmat2_pos) * scales.x), - dot(mat1_tex, - load_texel(t_qmat2, qmat2_pos + ivec3(0, 1, 0)) * scales.y), - dot(mat1_tex, - load_texel(t_qmat2, qmat2_pos + ivec3(0, 2, 0)) * scales.z), - dot(mat1_tex, - load_texel(t_qmat2, qmat2_pos + ivec3(0, 3, 0)) * scales.w)); - - outtex += sums; - - mat1_pos.x++; - qmat2_pos.x++; - } - - return outtex; -} - -#endif // USING_BUFFER - -#endif // Q_LINEAR_H diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index 4dd55be469..5642976b7f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -11,7 +11,6 @@ #include #include - #include namespace vkcompute { @@ -130,6 +129,94 @@ void add_q_8w_linear_node( } } +void add_q_8w_linear_optimized_node( + ComputeGraph& graph, + const ValueRef mat1, + const ValueRef q_mat2_data, + const ValueRef scales_data, + const ValueRef out) { + auto viewFn = VK_GET_OP_FN("aten.view_copy.default"); + ValueRef mat1_W_packed = mat1; + ValueRef out_W_packed = out; + if (!graph.is_buffer_storage(out) && + graph.packed_dim_of(mat1) != WHCN::kWidthDim) { + // Ensure mat1 is width packed + mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked); + viewFn(graph, {mat1, graph.add_none(), mat1_W_packed}); + // Ensure out is packed correctly + out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked); + } + ValueRef q_mat2 = + prepack_if_tensor_ref(graph, q_mat2_data, utils::kWidthPacked); + ValueRef scales = + prepack_if_tensor_ref(graph, scales_data, utils::kWidthPacked); + + std::string kernel_name = "q_8w_linear_optimized"; + kernel_name.reserve(kShaderNameReserve); + add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1_W_packed)); + add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2)); + std::vector mat1_sizes = graph.sizes_of(mat1_W_packed); + const int mat1_dims = mat1_sizes.size(); + if (mat1_dims == 3) { + kernel_name = "batch_" + kernel_name; + } + if (mat1_sizes.at(mat1_dims - 2) < 8) { + kernel_name += "_tile_row_2"; + } else { + kernel_name += "_tile_row_4"; + } + + add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed)); + + vkapi::ParamsBindList ubos({}); + + utils::uvec3 global_size; + utils::uvec3 local_size; + if (graph.is_buffer_storage(out)) { + ubos.append( + {graph.sizes_ubo(out_W_packed), + graph.strides_ubo(out_W_packed), + graph.numel_ubo(out_W_packed), + graph.sizes_ubo(mat1_W_packed), + graph.strides_ubo(mat1_W_packed), + graph.strides_ubo(q_mat2), + graph.strides_ubo(scales)}); + global_size = graph.create_global_wg_size(out_W_packed); + local_size = graph.create_local_wg_size(out_W_packed); + } else { + global_size = graph.logical_limits_of(out_W_packed); + ubos.append( + {graph.logical_limits_ubo(out_W_packed), + graph.sizes_ubo(mat1_W_packed)}); + if (mat1_sizes.at(mat1_dims - 2) < 8) { + global_size = global_size = utils::divup_vec(global_size, {1, 2, 1}); + } else { + global_size = utils::divup_vec(global_size, {1, 4, 1}); + } + local_size = {16, 3, 1}; + } + + graph.execute_nodes().emplace_back(new DispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_size, + local_size, + // Inputs and Outputs + {{out_W_packed, vkapi::MemoryAccessType::WRITE}, + {{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}}, + // Shader params buffers + ubos, + // Specialization Constants + {}, // spec_vars, + // Resizing Logic + resize_q_8w_linear_node)); + + if (!graph.is_buffer_storage(out)) { + viewFn(graph, {out_W_packed, graph.add_none(), out}); + } +} + void weight_int8pack_mm( ComputeGraph& graph, const std::vector& args) { diff --git a/build/build_android_llm_demo.sh b/build/build_android_llm_demo.sh index 4ad7c70c39..4eb47c7d05 100644 --- a/build/build_android_llm_demo.sh +++ b/build/build_android_llm_demo.sh @@ -37,7 +37,6 @@ build_android_native_library() { cmake . -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \ -DANDROID_ABI="${ANDROID_ABI}" \ - -DANDROID_PLATFORM=android-26 \ -DEXECUTORCH_ENABLE_LOGGING=ON \ -DEXECUTORCH_LOG_LEVEL=Info \ -DEXECUTORCH_BUILD_XNNPACK=ON \ @@ -66,7 +65,6 @@ build_android_native_library() { cmake extension/android \ -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake \ -DANDROID_ABI="${ANDROID_ABI}" \ - -DANDROID_PLATFORM=android-26 \ -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ -DEXECUTORCH_ENABLE_LOGGING=ON \ -DEXECUTORCH_LOG_LEVEL=Info \ diff --git a/docs/source/index.rst b/docs/source/index.rst index cf54fa2477..0f8ebedd9f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -135,6 +135,7 @@ Topics in this section will help you get started with ExecuTorch. export-to-executorch-api-reference executorch-runtime-api-reference + runtime-python-api-reference api-life-cycle .. toctree:: diff --git a/docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md b/docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md index 0157668d7f..d928377ff2 100644 --- a/docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md +++ b/docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md @@ -49,7 +49,6 @@ llama3/Meta-Llama-3-8B-Instruct/tokenizer.model -p -c -c `__ for how to convert a PyTorch ``nn.Module`` to an ExecuTorch ``.pte`` program file. Execution accepts and returns ``torch.Tensor`` values, making it a quick way to validate the correctness of the program. + +For detailed information on how APIs evolve and the deprecation process, please refer to the `ExecuTorch API Life Cycle and Deprecation Policy `__. + +.. automodule:: executorch.runtime +.. autoclass:: Runtime + :members: get, load_program + +.. autoclass:: OperatorRegistry + :members: operator_names + +.. autoclass:: Program + :members: method_names, load_method + +.. autoclass:: Method + :members: execute, metadata diff --git a/examples/devtools/example_runner/example_runner.cpp b/examples/devtools/example_runner/example_runner.cpp index 1aae0f2a98..1b3f1fd640 100644 --- a/examples/devtools/example_runner/example_runner.cpp +++ b/examples/devtools/example_runner/example_runner.cpp @@ -221,7 +221,7 @@ int main(int argc, char** argv) { method.ok(), "Loading of method %s failed with status 0x%" PRIx32, method_name, - method.error()); + static_cast(method.error())); ET_LOG(Info, "Method loaded."); void* debug_buffer = malloc(FLAGS_debug_buffer_size); @@ -242,7 +242,7 @@ int main(int argc, char** argv) { ET_CHECK_MSG( status == Error::Ok, "LoadBundledInput failed with status 0x%" PRIx32, - status); + static_cast(status)); ET_LOG(Info, "Inputs prepared."); @@ -252,7 +252,7 @@ int main(int argc, char** argv) { status == Error::Ok, "Execution of method %s failed with status 0x%" PRIx32, method_name, - status); + static_cast(status)); ET_LOG(Info, "Model executed successfully."); // Print the outputs. @@ -294,7 +294,7 @@ int main(int argc, char** argv) { ET_CHECK_MSG( status == Error::Ok, "Bundle verification failed with status 0x%" PRIx32, - status); + static_cast(status)); ET_LOG(Info, "Model verified successfully."); } diff --git a/examples/mediatek/mtk_build_examples.sh b/examples/mediatek/mtk_build_examples.sh index dd0caec264..df70489cf2 100755 --- a/examples/mediatek/mtk_build_examples.sh +++ b/examples/mediatek/mtk_build_examples.sh @@ -39,7 +39,6 @@ main() { -DBUCK2="$BUCK_PATH" \ -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \ -DANDROID_ABI=arm64-v8a \ - -DANDROID_PLATFORM=android-30 \ -DANDROID_NATIVE_API_LEVEL=23 \ -DEXECUTORCH_BUILD_NEURON=ON \ -DNEURON_BUFFER_ALLOCATOR_LIB="$NEURON_BUFFER_ALLOCATOR_LIB" \ @@ -59,7 +58,6 @@ main() { cmake -DCMAKE_PREFIX_PATH="${cmake_prefix_path}" \ -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \ -DANDROID_ABI=arm64-v8a \ - -DANDROID_PLATFORM=android-30 \ -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \ -DNEURON_BUFFER_ALLOCATOR_LIB="$NEURON_BUFFER_ALLOCATOR_LIB" \ -B"${example_build_dir}" \ diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index 27c9db2ffc..06225be2d1 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -234,7 +234,7 @@ def build_executorch_binary( shared_buffer=False, metadata=None, dump_intermediate_outputs=False, - custom_pass_config=None, + custom_pass_config=frozenset(), ): if quant_dtype is not None: quantizer = custom_quantizer or make_quantizer(quant_dtype=quant_dtype) diff --git a/examples/xnnpack/__init__.py b/examples/xnnpack/__init__.py index 81404dcf6b..d8de9f6a36 100644 --- a/examples/xnnpack/__init__.py +++ b/examples/xnnpack/__init__.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + from dataclasses import dataclass @@ -24,14 +26,14 @@ class XNNPACKOptions(object): "mv3": XNNPACKOptions(True, True), "resnet18": XNNPACKOptions(True, True), "resnet50": XNNPACKOptions(True, True), - "vit": XNNPACKOptions(False, True), # T161242362 + "vit": XNNPACKOptions(True, True), "w2l": XNNPACKOptions(True, True), "edsr": XNNPACKOptions(True, True), - "mobilebert": XNNPACKOptions(False, True), # T197452682 + "mobilebert": XNNPACKOptions(True, True), "llama2": XNNPACKOptions(False, True), "emformer_join": XNNPACKOptions(True, True), - "emformer_predict": XNNPACKOptions(False, True), # T197457838 - "emformer_transcribe": XNNPACKOptions(False, True), # T197449765 + "emformer_predict": XNNPACKOptions(True, True), + "emformer_transcribe": XNNPACKOptions(True, True), } diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index 0ee8b042a2..8f0e67900c 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -20,6 +20,9 @@ set(EXECUTORCH_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/../..") include(${EXECUTORCH_ROOT}/build/Utils.cmake) set(_common_compile_options -Wno-deprecated-declarations -fPIC) set(_common_include_directories ${EXECUTORCH_ROOT}/..) +if(NOT ANDROID_PLATFORM) + set(ANDROID_PLATFORM android-30) +endif() # We need to download fbjni library from maven, and use its "prefab" library # and headers, and link executorch library against that fbjni library. diff --git a/extension/benchmark/README.md b/extension/benchmark/README.md index 1190a85cfe..a9918864e9 100644 --- a/extension/benchmark/README.md +++ b/extension/benchmark/README.md @@ -61,7 +61,7 @@ Users can schedule a benchmarking workflow on a pull request through GitHub Acti ## Retrieving Benchmark Results -Currently, retrieving benchmark results involves manually extracting the `benchmark_results.json` from the `Customer_Artifacts.zip` stored on AWS S3 from the benchmarking job. This process is not yet streamlined. We are working on simplifying this process and linking the results directly to the dashboard, which will be available soon. +The easiest way to view benchmark results is on the [dashboard](./README.md#dashboard), while raw results for individual configurations can be manually accessed by downloading the `Customer_Artifacts.zip` from the CI. ## Feedback and Issue Reporting diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java index e2d46f8e8d..15f527475b 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java @@ -10,6 +10,7 @@ import android.app.Activity; import android.content.Intent; +import android.os.AsyncTask; import android.os.Bundle; import android.system.ErrnoException; import android.system.Os; @@ -47,43 +48,57 @@ protected void onCreate(Bundle savedInstanceState) { // TODO: Format the string with a parsable format Stats stats = new Stats(); - // Record the time it takes to load the model and the forward method - stats.loadStart = System.nanoTime(); - Module module = Module.load(model.getPath()); - stats.errorCode = module.loadMethod("forward"); - stats.loadEnd = System.nanoTime(); + new AsyncTask() { + @Override + protected Void doInBackground(Void... voids) { - for (int i = 0; i < numIter; i++) { - long start = System.nanoTime(); - module.forward(); - double forwardMs = (System.nanoTime() - start) * 1e-6; - stats.latency.add(forwardMs); - } + // Record the time it takes to load the model and the forward method + stats.loadStart = System.nanoTime(); + Module module = Module.load(model.getPath()); + stats.errorCode = module.loadMethod("forward"); + stats.loadEnd = System.nanoTime(); - final BenchmarkMetric.BenchmarkModel benchmarkModel = - BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", "")); - final List results = new ArrayList<>(); - // The list of metrics we have atm includes: - // Avg inference latency after N iterations - results.add( - new BenchmarkMetric( - benchmarkModel, - "avg_inference_latency(ms)", - stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f), - 0.0f)); - // Model load time - results.add( - new BenchmarkMetric( - benchmarkModel, "model_load_time(ms)", (stats.loadEnd - stats.loadStart) * 1e-6, 0.0f)); - // Load status - results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0)); + for (int i = 0; i < numIter; i++) { + long start = System.nanoTime(); + module.forward(); + double forwardMs = (System.nanoTime() - start) * 1e-6; + stats.latency.add(forwardMs); + } + return null; + } - try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { - Gson gson = new Gson(); - writer.write(gson.toJson(results)); - } catch (IOException e) { - e.printStackTrace(); - } + @Override + protected void onPostExecute(Void aVoid) { + + final BenchmarkMetric.BenchmarkModel benchmarkModel = + BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", "")); + final List results = new ArrayList<>(); + // The list of metrics we have atm includes: + // Avg inference latency after N iterations + results.add( + new BenchmarkMetric( + benchmarkModel, + "avg_inference_latency(ms)", + stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f), + 0.0f)); + // Model load time + results.add( + new BenchmarkMetric( + benchmarkModel, + "model_load_time(ms)", + (stats.loadEnd - stats.loadStart) * 1e-6, + 0.0f)); + // Load status + results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0)); + + try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { + Gson gson = new Gson(); + writer.write(gson.toJson(results)); + } catch (IOException e) { + e.printStackTrace(); + } + } + }.execute(); } } diff --git a/runtime/__init__.py b/runtime/__init__.py index 80ffeeba03..4ed99ddae0 100644 --- a/runtime/__init__.py +++ b/runtime/__init__.py @@ -4,15 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Interface to the native C++ ExecuTorch runtime. - +""" Example usage: -.. code-block:: text + +.. code-block:: python from pathlib import Path import torch - from executorch.runtime import Verification, Runtime + from executorch.runtime import Verification, Runtime, Program, Method et_runtime: Runtime = Runtime.get() program: Program = et_runtime.load_program( @@ -28,6 +28,7 @@ print(f" outputs: {outputs}") Example output: + .. code-block:: text Program methods: ('forward', 'forward2') @@ -107,6 +108,9 @@ def __init__(self, module: ExecuTorchModule, data: Optional[bytes]) -> None: @property def method_names(self) -> Set[str]: + """ + Returns method names of the `Program` as a set of strings. + """ return set(self._methods.keys()) def load_method(self, name: str) -> Optional[Method]: @@ -130,7 +134,9 @@ def __init__(self, legacy_module: ModuleType) -> None: @property def operator_names(self) -> Set[str]: - """The names of all registered operators.""" + """ + Returns the names of all registered operators as a set of strings. + """ return set(self._legacy_module._get_operator_names()) diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index 0838529bc5..a05d789a80 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -1191,8 +1191,9 @@ Error Method::step() { static_cast(step_state_.chain_idx), static_cast(step_state_.instr_idx)); EXECUTORCH_SCOPE_PROF("Method::step"); - internal::EventTracerProfileMethodScope event_tracer_profile_scope = - internal::EventTracerProfileMethodScope(event_tracer_, "Method::step"); + EventTracerEntry event_tracer_entry = + internal::event_tracer_begin_profiling_event( + event_tracer_, "Method::step"); ET_CHECK_OR_RETURN_ERROR( initialized(), InvalidState, @@ -1218,6 +1219,7 @@ Error Method::step() { return status; } + internal::event_tracer_end_profiling_event(event_tracer_, event_tracer_entry); // end of the current chain, advance to the next chain if (step_state_.instr_idx == num_instructions) { step_state_.instr_idx = 0; @@ -1233,8 +1235,9 @@ Error Method::experimental_step() { Error Method::execute() { internal::event_tracer_create_event_block(event_tracer_, "Execute"); - internal::EventTracerProfileMethodScope event_tracer_profile_scope = - internal::EventTracerProfileMethodScope(event_tracer_, "Method::execute"); + EventTracerEntry event_tracer_entry = + internal::event_tracer_begin_profiling_event( + event_tracer_, "Method::execute"); EXECUTORCH_SCOPE_PROF("Method::execute"); ET_CHECK_OR_RETURN_ERROR( initialized(), @@ -1270,7 +1273,7 @@ Error Method::execute() { } } } - + internal::event_tracer_end_profiling_event(event_tracer_, event_tracer_entry); log_outputs(); // TODO(jakeszwe, dbort): Decide on calling execute back to back without