From 19cbafea076de75ed4b75cfc2800a26f61b22860 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Wed, 16 Oct 2024 23:11:03 -0700 Subject: [PATCH] [minibench] Run benchmark on background thread --- .../pytorch/minibench/BenchmarkActivity.java | 83 +++++++++++-------- 1 file changed, 49 insertions(+), 34 deletions(-) diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java index e2d46f8e8d..15f527475b 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java @@ -10,6 +10,7 @@ import android.app.Activity; import android.content.Intent; +import android.os.AsyncTask; import android.os.Bundle; import android.system.ErrnoException; import android.system.Os; @@ -47,43 +48,57 @@ protected void onCreate(Bundle savedInstanceState) { // TODO: Format the string with a parsable format Stats stats = new Stats(); - // Record the time it takes to load the model and the forward method - stats.loadStart = System.nanoTime(); - Module module = Module.load(model.getPath()); - stats.errorCode = module.loadMethod("forward"); - stats.loadEnd = System.nanoTime(); + new AsyncTask() { + @Override + protected Void doInBackground(Void... voids) { - for (int i = 0; i < numIter; i++) { - long start = System.nanoTime(); - module.forward(); - double forwardMs = (System.nanoTime() - start) * 1e-6; - stats.latency.add(forwardMs); - } + // Record the time it takes to load the model and the forward method + stats.loadStart = System.nanoTime(); + Module module = Module.load(model.getPath()); + stats.errorCode = module.loadMethod("forward"); + stats.loadEnd = System.nanoTime(); - final BenchmarkMetric.BenchmarkModel benchmarkModel = - BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", "")); - final List results = new ArrayList<>(); - // The list of metrics we have atm includes: - // Avg inference latency after N iterations - results.add( - new BenchmarkMetric( - benchmarkModel, - "avg_inference_latency(ms)", - stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f), - 0.0f)); - // Model load time - results.add( - new BenchmarkMetric( - benchmarkModel, "model_load_time(ms)", (stats.loadEnd - stats.loadStart) * 1e-6, 0.0f)); - // Load status - results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0)); + for (int i = 0; i < numIter; i++) { + long start = System.nanoTime(); + module.forward(); + double forwardMs = (System.nanoTime() - start) * 1e-6; + stats.latency.add(forwardMs); + } + return null; + } - try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { - Gson gson = new Gson(); - writer.write(gson.toJson(results)); - } catch (IOException e) { - e.printStackTrace(); - } + @Override + protected void onPostExecute(Void aVoid) { + + final BenchmarkMetric.BenchmarkModel benchmarkModel = + BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", "")); + final List results = new ArrayList<>(); + // The list of metrics we have atm includes: + // Avg inference latency after N iterations + results.add( + new BenchmarkMetric( + benchmarkModel, + "avg_inference_latency(ms)", + stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f), + 0.0f)); + // Model load time + results.add( + new BenchmarkMetric( + benchmarkModel, + "model_load_time(ms)", + (stats.loadEnd - stats.loadStart) * 1e-6, + 0.0f)); + // Load status + results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0)); + + try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { + Gson gson = new Gson(); + writer.write(gson.toJson(results)); + } catch (IOException e) { + e.printStackTrace(); + } + } + }.execute(); } }