diff --git a/api/src/main/java/ai/djl/inference/Predictor.java b/api/src/main/java/ai/djl/inference/Predictor.java
index d9b20e3ef9e..4de441faa98 100644
--- a/api/src/main/java/ai/djl/inference/Predictor.java
+++ b/api/src/main/java/ai/djl/inference/Predictor.java
@@ -17,6 +17,7 @@
 import ai.djl.inference.streaming.StreamingBlock;
 import ai.djl.inference.streaming.StreamingTranslator;
 import ai.djl.inference.streaming.StreamingTranslator.StreamOutput;
+import ai.djl.metric.Dimension;
 import ai.djl.metric.Metrics;
 import ai.djl.metric.Unit;
 import ai.djl.ndarray.LazyNDArray;
@@ -94,6 +95,7 @@ public class Predictor implements AutoCloseable {
     protected Metrics metrics;
     protected Block block;
     protected ParameterStore parameterStore;
+    protected Dimension dimension;
 
     /**
      * Creates a new instance of {@code BasePredictor} with the given {@link Model} and {@link
@@ -116,6 +118,7 @@ public Predictor(Model model, Translator translator, Device device, boolea
         this.translator = translator;
         block = model.getBlock();
         parameterStore = new ParameterStore(manager, copy);
+        dimension = new Dimension("Model", model.getProperty("metric_dimension", "model"));
     }
 
     /**
@@ -315,7 +318,7 @@ private void preprocessEnd(NDList list) {
             long tmp = System.nanoTime();
             long duration = (tmp - timestamp) / 1000;
             timestamp = tmp;
-            metrics.addMetric("Preprocess", duration, Unit.MICROSECONDS);
+            metrics.addMetric("Preprocess", duration, Unit.MICROSECONDS, dimension);
         }
     }
 
@@ -325,7 +328,7 @@ private void predictEnd(NDList list) {
             long tmp = System.nanoTime();
             long duration = (tmp - timestamp) / 1000;
             timestamp = tmp;
-            metrics.addMetric("Inference", duration, Unit.MICROSECONDS);
+            metrics.addMetric("Inference", duration, Unit.MICROSECONDS, dimension);
         }
     }
 
@@ -334,8 +337,9 @@ private void postProcessEnd(long begin) {
             long tmp = System.nanoTime();
             long duration = (tmp - timestamp) / 1000;
             timestamp = tmp;
-            metrics.addMetric("Postprocess", duration, Unit.MICROSECONDS);
-            metrics.addMetric("Total", (tmp - begin) / 1000, Unit.MICROSECONDS);
+            metrics.addMetric("Postprocess", duration, Unit.MICROSECONDS, dimension);
+            long prediction = (tmp - begin) / 1000;
+            metrics.addMetric("Prediction", prediction, Unit.MICROSECONDS, dimension);
         }
     }
 
diff --git a/api/src/main/java/ai/djl/metric/Metrics.java b/api/src/main/java/ai/djl/metric/Metrics.java
index 011909499b8..19fcc0f681d 100644
--- a/api/src/main/java/ai/djl/metric/Metrics.java
+++ b/api/src/main/java/ai/djl/metric/Metrics.java
@@ -103,9 +103,10 @@ public void addMetric(String name, Number value) {
      * @param name the metric name
      * @param value the metric value
      * @param unit the metric unit
+     * @param dimensions the metric dimensions
      */
-    public void addMetric(String name, Number value, Unit unit) {
-        addMetric(new Metric(name, value, unit));
+    public void addMetric(String name, Number value, Unit unit, Dimension... dimensions) {
+        addMetric(new Metric(name, value, unit, dimensions));
     }
 
     /**