Skip to content

Commit

Permalink
[minibench] Run benchmark on background thread
Browse files Browse the repository at this point in the history
  • Loading branch information
kirklandsign committed Oct 17, 2024
1 parent ec1c431 commit 19cbafe
Showing 1 changed file with 49 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import android.app.Activity;
import android.content.Intent;
import android.os.AsyncTask;
import android.os.Bundle;
import android.system.ErrnoException;
import android.system.Os;
Expand Down Expand Up @@ -47,43 +48,57 @@ protected void onCreate(Bundle savedInstanceState) {
// TODO: Format the string with a parsable format
Stats stats = new Stats();

// Record the time it takes to load the model and the forward method
stats.loadStart = System.nanoTime();
Module module = Module.load(model.getPath());
stats.errorCode = module.loadMethod("forward");
stats.loadEnd = System.nanoTime();
new AsyncTask<Void, Void, Void>() {
@Override
protected Void doInBackground(Void... voids) {

for (int i = 0; i < numIter; i++) {
long start = System.nanoTime();
module.forward();
double forwardMs = (System.nanoTime() - start) * 1e-6;
stats.latency.add(forwardMs);
}
// Record the time it takes to load the model and the forward method
stats.loadStart = System.nanoTime();
Module module = Module.load(model.getPath());
stats.errorCode = module.loadMethod("forward");
stats.loadEnd = System.nanoTime();

final BenchmarkMetric.BenchmarkModel benchmarkModel =
BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
final List<BenchmarkMetric> results = new ArrayList<>();
// The list of metrics we have atm includes:
// Avg inference latency after N iterations
results.add(
new BenchmarkMetric(
benchmarkModel,
"avg_inference_latency(ms)",
stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
0.0f));
// Model load time
results.add(
new BenchmarkMetric(
benchmarkModel, "model_load_time(ms)", (stats.loadEnd - stats.loadStart) * 1e-6, 0.0f));
// Load status
results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));
for (int i = 0; i < numIter; i++) {
long start = System.nanoTime();
module.forward();
double forwardMs = (System.nanoTime() - start) * 1e-6;
stats.latency.add(forwardMs);
}
return null;
}

try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
Gson gson = new Gson();
writer.write(gson.toJson(results));
} catch (IOException e) {
e.printStackTrace();
}
@Override
protected void onPostExecute(Void aVoid) {

final BenchmarkMetric.BenchmarkModel benchmarkModel =
BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
final List<BenchmarkMetric> results = new ArrayList<>();
// The list of metrics we have atm includes:
// Avg inference latency after N iterations
results.add(
new BenchmarkMetric(
benchmarkModel,
"avg_inference_latency(ms)",
stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
0.0f));
// Model load time
results.add(
new BenchmarkMetric(
benchmarkModel,
"model_load_time(ms)",
(stats.loadEnd - stats.loadStart) * 1e-6,
0.0f));
// Load status
results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));

try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
Gson gson = new Gson();
writer.write(gson.toJson(results));
} catch (IOException e) {
e.printStackTrace();
}
}
}.execute();
}
}

Expand Down

0 comments on commit 19cbafe

Please sign in to comment.