From 3da412c874d7b820feb6c537d18f976b24603faa Mon Sep 17 00:00:00 2001 From: Siva Rama Krishna Reddy B Date: Fri, 30 Dec 2022 13:27:47 +0530 Subject: [PATCH 1/2] [BENCHMARK][ADRENO] Adreno Benchmarks with texture Benchmarks for various networks listed below with fp16 and fp32. resnet-18, resnet-34, resnet-50, vgg-16, vgg-19, densenet-121, inception_v3, mobilenetv1, squeezenet_v1.0, squeezenet_v1.1 --- .../adreno/adreno_gpu_bench_texture.py | 277 ++++++++++++++++++ apps/benchmark/adreno/bench.py | 61 ++++ tests/scripts/ci.py | 7 + 3 files changed, 345 insertions(+) create mode 100755 apps/benchmark/adreno/adreno_gpu_bench_texture.py create mode 100755 apps/benchmark/adreno/bench.py diff --git a/apps/benchmark/adreno/adreno_gpu_bench_texture.py b/apps/benchmark/adreno/adreno_gpu_bench_texture.py new file mode 100755 index 000000000000..0fab2012251d --- /dev/null +++ b/apps/benchmark/adreno/adreno_gpu_bench_texture.py @@ -0,0 +1,277 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Benchmark script for various models on Adreno GPU. +""" +import argparse + +import numpy as np + +import os +import sys +import tvm +from tvm import te +from tvm.relay import testing +from tvm.contrib.utils import tempdir +import tvm.contrib.graph_executor as runtime +from tvm import relay +from tvm import autotvm +from tvm.contrib import utils, ndk + + +def get_network(name, batch_size, dtype="float32"): + """Get the symbol definition and random weight of a network + + Parameters + ---------- + name: str + The name of the network, can be 'resnet-18', 'resnet-50', 'vgg-16', 'inception_v3', 'mobilenet', ... + batch_size: int + batch size + dtype: str + Data type + + Returns + ------- + net: tvm.IRModule + The relay function of network definition + params: dict + The random parameters for benchmark + input_shape: tuple + The shape of input tensor + output_shape: tuple + The shape of output tensor + """ + input_shape = (batch_size, 3, 224, 224) + output_shape = (batch_size, 1000) + + if name == "mobilenet": + net, params = testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype) + elif name == "inception_v3": + input_shape = (batch_size, 3, 299, 299) + net, params = testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype) + elif "resnet" in name: + n_layer = int(name.split("-")[1]) + net, params = testing.resnet.get_workload( + num_layers=n_layer, batch_size=batch_size, dtype=dtype + ) + elif "vgg" in name: + n_layer = int(name.split("-")[1]) + net, params = testing.vgg.get_workload( + num_layers=n_layer, batch_size=batch_size, dtype=dtype + ) + elif "densenet" in name: + n_layer = int(name.split("-")[1]) + net, params = testing.densenet.get_workload( + densenet_size=n_layer, batch_size=batch_size, dtype=dtype + ) + elif "squeezenet" in name: + version = name.split("_v")[1] + net, params = testing.squeezenet.get_workload( + batch_size=batch_size, version=version, dtype=dtype + ) + elif name == "mxnet": + # an example for mxnet model + from mxnet.gluon.model_zoo.vision import get_model + + block = get_model("resnet18_v1", pretrained=True) + net, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype) + net = net["main"] + net = relay.Function( + net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs + ) + net = tvm.IRModule.from_expr(net) + else: + raise ValueError("Unsupported network: " + name) + + return net, params, input_shape, output_shape + + +def print_progress(msg): + """print progress message + + Parameters + ---------- + msg: str + The message to print + """ + sys.stdout.write(msg + "\r") + sys.stdout.flush() + + +def tune_tasks( + tasks, + measure_option, + n_trial=1024, + early_stopping=None, + log_filename="tuning.log", +): + from tvm.autotvm.tuner import XGBTuner + + tmp_log_file = log_filename + ".tmp" + + for i, tsk in enumerate(reversed(tasks)): + print("Task: ", tsk) + prefix = "[Task %2d/%2d] " % (i + 1, len(tasks)) + tuner_obj = XGBTuner(tsk, loss_type="rank") + + tsk_trial = min(n_trial, len(tsk.config_space)) + tuner_obj.tune( + n_trial=tsk_trial, + early_stopping=early_stopping, + measure_option=measure_option, + callbacks=[ + autotvm.callback.progress_bar(tsk_trial, prefix=prefix), + autotvm.callback.log_to_file(tmp_log_file), + ], + ) + + autotvm.record.pick_best(tmp_log_file, log_filename) + + +def evaluate_network(network, target, target_host, dtype, repeat): + print_progress(network) + net, params, input_shape, output_shape = get_network(network, batch_size=1, dtype=dtype) + + # Auto Tuning + tune_log = "adreno-" + network + "-" + dtype + ".log" + tuning_options = { + "log_filename": tune_log, + "early_stopping": None, + "measure_option": autotvm.measure_option( + builder=autotvm.LocalBuilder(build_func=ndk.create_shared, timeout=15), + runner=autotvm.RPCRunner( + args.rpc_key, + host=args.host, + port=args.port, + number=3, + timeout=600, + ), + ), + } + if args.tune: + tasks = autotvm.task.extract_from_program( + net, target=target, target_host=target_host, params=params + ) + tune_tasks(tasks, **tuning_options) + + print_progress("%-20s building..." % network) + + # Build the tuning log + if os.path.exists(tune_log): + with autotvm.apply_history_best(tune_log): + with tvm.transform.PassContext(opt_level=3): + lib = relay.build( + net, target=tvm.target.Target(target, host=target_host), params=params + ) + else: + print("WARNING: Benchmark running with out tuning cache file - ", tune_log) + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(net, target=tvm.target.Target(target, host=target_host), params=params) + + tmp = tempdir() + + filename = "%s.so" % network + lib.export_library(tmp.relpath(filename), ndk.create_shared) + + # upload library and params + print_progress("%-20s uploading..." % network) + + # connect to remote device + tracker = tvm.rpc.connect_tracker(args.host, args.port) + remote = tracker.request(args.rpc_key) + + dev = remote.device(str(target), 0) + remote.upload(tmp.relpath(filename)) + + rlib = remote.load_module(filename) + module = runtime.GraphModule(rlib["default"](dev)) + data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype)) + module.set_input("data", data_tvm) + + # evaluate + print_progress("%-20s evaluating..." % network) + ftimer = module.module.time_evaluator("run", dev, number=1, repeat=repeat) + prof_res = np.array(ftimer().results) * 1000 # multiply 1000 for converting to millisecond + print( + "%-20s %-19s (%s)" + % (network+"-"+dtype, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)) + ) + return (np.mean(prof_res), np.std(prof_res)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--network", + type=str, + choices=[ + "resnet-18", + "resnet-34", + "resnet-50", + "vgg-16", + "vgg-19", + "densenet-121", + "inception_v3", + "mobilenet", + "squeezenet_v1.0", + "squeezenet_v1.1", + ], + help="The name of neural network", + ) + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=9190) + parser.add_argument("--rpc-key", type=str, default="android") + parser.add_argument("--repeat", type=int, default=30) + parser.add_argument("--tune", type=bool, default=False) + args = parser.parse_args() + + if args.network is None: + networks = [ + "resnet-18", + "resnet-34", + "resnet-50", + "vgg-16", + "vgg-19", + "densenet-121", + "inception_v3", + "mobilenet", + "squeezenet_v1.0", + "squeezenet_v1.1", + ] + else: + networks = [args.network] + + target = "opencl -device=adreno" + target_host = "llvm -mtriple=arm64-linux-android" + + print("--------------------------------------------------") + print("%-20s %-20s" % ("Network Name", "Mean Inference Time (std dev)")) + print("--------------------------------------------------") + + results = {} + + for network in networks: + ftime = evaluate_network(network, target, target_host, "float32", args.repeat) + results[network + "-float32"] = ftime + ftime = evaluate_network(network, target, target_host, "float16", args.repeat) + results[network + "-float16"] = ftime + + print("----------------------------------------------------------------------") + print("%-30s %-30s" % ("Network Name", "Mean Inference Time (std dev)")) + print("----------------------------------------------------------------------") + for key, val in results.items(): + print("%-30s %-30s (%s)" % (key, "%.2f ms" % val[0], "%.2f ms" % val[1])) diff --git a/apps/benchmark/adreno/bench.py b/apps/benchmark/adreno/bench.py new file mode 100755 index 000000000000..265c349d782c --- /dev/null +++ b/apps/benchmark/adreno/bench.py @@ -0,0 +1,61 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -euxo pipefail + +echo "Bench called" + +source tests/scripts/setup-pytest-env.sh +export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/apps/extension/python +export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}" + +export TVM_TRACKER_HOST=127.0.0.1 +export TVM_TRACKER_PORT=$(((RANDOM % 100) + 9100)) +export RPC_DEVICE_KEY="android" +export TVM_NDK_CC="${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang" + +env PYTHONPATH=python python3 -m tvm.exec.rpc_tracker --host "${TVM_TRACKER_HOST}" --port "${TVM_TRACKER_PORT}" & +TRACKER_PID=$! +sleep 5 # Wait for tracker to bind + +export ANDROID_SERIAL=$2 + +adb shell "mkdir -p /data/local/tmp/tvm_ci" +adb push build-adreno-target/tvm_rpc /data/local/tmp/tvm_ci/tvm_rpc_ci +adb push build-adreno-target/libtvm_runtime.so /data/local/tmp/tvm_ci + +adb reverse tcp:${TVM_TRACKER_PORT} tcp:${TVM_TRACKER_PORT} +adb forward tcp:5000 tcp:5000 +adb forward tcp:5001 tcp:5001 +adb forward tcp:5002 tcp:5002 +env adb shell "cd /data/local/tmp/tvm_ci; killall -9 tvm_rpc_ci; sleep 2; LD_LIBRARY_PATH=/data/local/tmp/tvm_ci/ ./tvm_rpc_ci server --host=0.0.0.0 --port=5000 --port-end=5010 --tracker=127.0.0.1:${TVM_TRACKER_PORT} --key=${RPC_DEVICE_KEY}" & +DEVICE_PID=$! +sleep 5 # Wait for the device connections +trap "{ kill ${TRACKER_PID}; kill ${DEVICE_PID}; }" 0 + +# cleanup pycache +find . -type f -path "*.pyc" | xargs rm -f +# Test TVM +make cython3 + +if [ "texture" == $1 ] ; then + python3 apps/benchmark/adreno/adreno_gpu_bench_texture.py --host ${TVM_TRACKER_HOST} --port ${TVM_TRACKER_PORT} --rpc-key ${RPC_DEVICE_KEY} +fi + +kill ${TRACKER_PID} +kill ${DEVICE_PID} diff --git a/tests/scripts/ci.py b/tests/scripts/ci.py index 16389d29354c..25fed3c3ab04 100755 --- a/tests/scripts/ci.py +++ b/tests/scripts/ci.py @@ -727,6 +727,13 @@ def add_subparser( "./tests/scripts/task_python_adreno.sh " + os.environ.get("ANDROID_SERIAL", ""), ], ), + "benchmarks": ( + "run Adreno Texture Benchmarks", + [ + "./apps/benchmark/adreno/bench.py texture " + + os.environ.get("ANDROID_SERIAL", ""), + ], + ), }, ), ] From e8b9f2350049fd1e81acdc87c9f8384ae59a70f9 Mon Sep 17 00:00:00 2001 From: Siva Rama Krishna Reddy B Date: Sat, 31 Dec 2022 07:42:25 +0530 Subject: [PATCH 2/2] * lint error --- apps/benchmark/adreno/adreno_gpu_bench_texture.py | 7 ++++--- apps/benchmark/adreno/{bench.py => bench.sh} | 2 -- tests/scripts/ci.py | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) rename apps/benchmark/adreno/{bench.py => bench.sh} (99%) diff --git a/apps/benchmark/adreno/adreno_gpu_bench_texture.py b/apps/benchmark/adreno/adreno_gpu_bench_texture.py index 0fab2012251d..2228cda31a39 100755 --- a/apps/benchmark/adreno/adreno_gpu_bench_texture.py +++ b/apps/benchmark/adreno/adreno_gpu_bench_texture.py @@ -178,9 +178,10 @@ def evaluate_network(network, target, target_host, dtype, repeat): net, target=tvm.target.Target(target, host=target_host), params=params ) else: - print("WARNING: Benchmark running with out tuning cache file - ", tune_log) with tvm.transform.PassContext(opt_level=3): - lib = relay.build(net, target=tvm.target.Target(target, host=target_host), params=params) + lib = relay.build( + net, target=tvm.target.Target(target, host=target_host), params=params + ) tmp = tempdir() @@ -208,7 +209,7 @@ def evaluate_network(network, target, target_host, dtype, repeat): prof_res = np.array(ftimer().results) * 1000 # multiply 1000 for converting to millisecond print( "%-20s %-19s (%s)" - % (network+"-"+dtype, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)) + % (network + "-" + dtype, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)) ) return (np.mean(prof_res), np.std(prof_res)) diff --git a/apps/benchmark/adreno/bench.py b/apps/benchmark/adreno/bench.sh similarity index 99% rename from apps/benchmark/adreno/bench.py rename to apps/benchmark/adreno/bench.sh index 265c349d782c..7d46685b8654 100755 --- a/apps/benchmark/adreno/bench.py +++ b/apps/benchmark/adreno/bench.sh @@ -18,8 +18,6 @@ set -euxo pipefail -echo "Bench called" - source tests/scripts/setup-pytest-env.sh export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/apps/extension/python export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}" diff --git a/tests/scripts/ci.py b/tests/scripts/ci.py index 25fed3c3ab04..756b269d0e50 100755 --- a/tests/scripts/ci.py +++ b/tests/scripts/ci.py @@ -730,7 +730,7 @@ def add_subparser( "benchmarks": ( "run Adreno Texture Benchmarks", [ - "./apps/benchmark/adreno/bench.py texture " + "./apps/benchmark/adreno/bench.sh texture " + os.environ.get("ANDROID_SERIAL", ""), ], ),