Skip to content

Commit

Permalink
Merge branch 'master' into gather_dev
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed May 6, 2022
2 parents 7bd40e5 + 05a685e commit b887bc1
Show file tree
Hide file tree
Showing 36 changed files with 825 additions and 154 deletions.
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
![DeepJavaLibrary](website/img/deepjavalibrary.png?raw=true "Deep Java Library")

![Continuous](https://github.com/deepjavalibrary/djl/workflows/Continuous/badge.svg)
![Continuous PyTorch](https://github.com/deepjavalibrary/djl/workflows/Continous%20PyTorch/badge.svg)
![Continuous Tensorflow](https://github.com/deepjavalibrary/djl/workflows/Continuous%20Tensorflow/badge.svg)
![Docs](https://github.com/deepjavalibrary/djl/workflows/Docs/badge.svg)
![Nightly Publish](https://github.com/deepjavalibrary/djl/workflows/Nightly%20Publish/badge.svg)

Expand Down
15 changes: 8 additions & 7 deletions api/src/main/java/ai/djl/inference/Predictor.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ai.djl.Device;
import ai.djl.Model;
import ai.djl.metric.Metrics;
import ai.djl.metric.Unit;
import ai.djl.ndarray.LazyNDArray;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
Expand Down Expand Up @@ -228,29 +229,29 @@ private void preprocessEnd(NDList list) {
if (metrics != null) {
waitToRead(list);
long tmp = System.nanoTime();
long duration = tmp - timestamp;
long duration = (tmp - timestamp) / 1000;
timestamp = tmp;
metrics.addMetric("Preprocess", duration, "nano");
metrics.addMetric("Preprocess", duration, Unit.MICROSECONDS);
}
}

private void predictEnd(NDList list) {
if (metrics != null) {
waitToRead(list);
long tmp = System.nanoTime();
long duration = tmp - timestamp;
long duration = (tmp - timestamp) / 1000;
timestamp = tmp;
metrics.addMetric("Inference", duration, "nano");
metrics.addMetric("Inference", duration, Unit.MICROSECONDS);
}
}

private void postProcessEnd(long begin) {
if (metrics != null) {
long tmp = System.nanoTime();
long duration = tmp - timestamp;
long duration = (tmp - timestamp) / 1000;
timestamp = tmp;
metrics.addMetric("Postprocess", duration, "nano");
metrics.addMetric("Total", tmp - begin, "nano");
metrics.addMetric("Postprocess", duration, Unit.MICROSECONDS);
metrics.addMetric("Total", (tmp - begin) / 1000, Unit.MICROSECONDS);
}
}

Expand Down
57 changes: 57 additions & 0 deletions api/src/main/java/ai/djl/metric/Dimension.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.metric;

import com.google.gson.annotations.SerializedName;

/** A class represents a metric dimension. */
public class Dimension {

@SerializedName("Name")
private String name;

@SerializedName("Value")
private String value;

/** Constructs a new {@code Dimension} instance. */
public Dimension() {}

/**
* Constructs a new {@code Dimension} instance.
*
* @param name the dimension name
* @param value the dimension value
*/
public Dimension(String name, String value) {
this.name = name;
this.value = value;
}

/**
* Returns the dimension name.
*
* @return the dimension name
*/
public String getName() {
return name;
}

/**
* Returns the dimension value.
*
* @return the dimension value
*/
public String getValue() {
return value;
}
}
151 changes: 137 additions & 14 deletions api/src/main/java/ai/djl/metric/Metric.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,39 @@
*/
package ai.djl.metric;

import com.google.gson.annotations.SerializedName;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
* A class representing a single recorded {@code Metric} value.
*
* @see Metrics
*/
public class Metric {

private static final Pattern PATTERN =
Pattern.compile(
"\\s*([\\w\\s]+)\\.([\\w\\s]+):([0-9\\-,.e]+)(?>\\|#([^|]*))?(?>\\|(\\d+))?");

private static final Dimension HOST = new Dimension("Host", getLocalHostName());

@SerializedName("MetricName")
private String metricName;
private Number value;

@SerializedName("Value")
private String value;

@SerializedName("Unit")
private String unit;
private long timestamp;

@SerializedName("Dimensions")
private Dimension[] dimensions;

@SerializedName("Timestamp")
private String timestamp;

/**
* Constructs a {@code Metric} instance with the specified {@code metricName} and <code>
Expand All @@ -32,7 +54,7 @@ public class Metric {
* @param value the metric value
*/
public Metric(String metricName, Number value) {
this(metricName, value, "count");
this(metricName, value, Unit.COUNT);
}

/**
Expand All @@ -43,11 +65,43 @@ public Metric(String metricName, Number value) {
* @param value the metric value
* @param unit the metric unit
*/
public Metric(String metricName, Number value, String unit) {
public Metric(String metricName, Number value, Unit unit) {
this(metricName, value.toString(), unit.getValue(), null, HOST);
}

/**
* Constructs a {@code Metric} instance with the specified {@code metricName}, <code>value
* </code>, and {@code unit}.
*
* @param metricName the metric name
* @param value the metric value
* @param unit the metric unit
* @param dimensions the metric dimensions
*/
public Metric(String metricName, Number value, Unit unit, Dimension... dimensions) {
this(metricName, value.toString(), unit.getValue(), null, dimensions);
}

/**
* Constructs a new {@code Metric} instance.
*
* @param metricName the metric name
* @param value the metric value
* @param unit the metric unit
* @param timestamp the metric timestamp
* @param dimensions the metric dimensions
*/
private Metric(
String metricName,
String value,
String unit,
String timestamp,
Dimension... dimensions) {
this.metricName = metricName;
this.value = value;
this.unit = unit;
timestamp = System.currentTimeMillis();
this.value = value;
this.timestamp = timestamp;
this.dimensions = dimensions;
}

/**
Expand All @@ -60,35 +114,104 @@ public String getMetricName() {
}

/**
* Returns the value of the {@code Metric}.
* Returns the int value of the {@code Metric}.
*
* @return the metric value
* @return the metric value in int
*/
public Number getValue() {
return value;
public Double getValue() {
return Double.valueOf(value);
}

/**
* Returns the unit of the {@code Metric}.
*
* @return the metric unit
*/
public String getUnit() {
return unit;
public Unit getUnit() {
return Unit.fromValue(unit);
}

/**
* Returns the timestamp of the {@code Metric}.
*
* @return the metric timestamp
*/
public long getTimestamp() {
public String getTimestamp() {
return timestamp;
}

/**
* Returns the metric dimensions.
*
* @return the metric dimensions
*/
public Dimension[] getDimensions() {
return dimensions;
}

/**
* Returns a {@code Metric} instance parsed from the log string.
*
* @param line the input string
* @return a {@code Metric} object
*/
public static Metric parse(String line) {
// DiskAvailable.Gigabytes:311|#Host:localhost|1650953744320
Matcher matcher = PATTERN.matcher(line);
if (!matcher.matches()) {
return null;
}

String metricName = matcher.group(1);
String unit = matcher.group(2);
String value = matcher.group(3);
String dimension = matcher.group(4);
String timestamp = matcher.group(5);

Dimension[] dimensions = null;
if (dimension != null) {
String[] dims = dimension.split(",");
dimensions = new Dimension[dims.length];
int index = 0;
for (String dim : dims) {
String[] pair = dim.split(":");
if (pair.length == 2) {
dimensions[index++] = new Dimension(pair[0], pair[1]);
}
}
}

return new Metric(metricName, value, unit, timestamp, dimensions);
}

/** {@inheritDoc} */
@Override
public String toString() {
return metricName + '.' + unit + ':' + value + "|#timestamp:" + timestamp;
StringBuilder sb = new StringBuilder(128);
sb.append(metricName).append('.').append(unit).append(':').append(value);
if (dimensions != null) {
boolean first = true;
for (Dimension dimension : dimensions) {
if (first) {
sb.append("|#");
first = false;
} else {
sb.append(',');
}
sb.append(dimension.getName()).append(':').append(dimension.getValue());
}
}
if (timestamp != null) {
sb.append('|').append(timestamp);
}
return sb.toString();
}

private static String getLocalHostName() {
try {
return InetAddress.getLocalHost().getHostName();
} catch (UnknownHostException e) {
return "Unknown";
}
}
}
26 changes: 3 additions & 23 deletions api/src/main/java/ai/djl/metric/Metrics.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
*/
package ai.djl.metric;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
Expand Down Expand Up @@ -41,8 +40,6 @@
*/
public class Metrics {

private static final MetricValueComparator VALUE_COMPARATOR = new MetricValueComparator();

private Map<String, List<Metric>> metrics;

/** Constructs an empty {@code Metrics} instance. */
Expand Down Expand Up @@ -80,7 +77,7 @@ public void addMetric(String name, Number value) {
* @param value the metric value
* @param unit the metric unit
*/
public void addMetric(String name, Number value, String unit) {
public void addMetric(String name, Number value, Unit unit) {
addMetric(new Metric(name, value, unit));
}

Expand Down Expand Up @@ -146,7 +143,7 @@ public Metric percentile(String metricName, int percentile) {
}

List<Metric> list = new ArrayList<>(metric);
list.sort(VALUE_COMPARATOR);
list.sort(Comparator.comparingDouble(Metric::getValue));
int index = metric.size() * percentile / 100;
return list.get(index);
}
Expand All @@ -163,23 +160,6 @@ public double mean(String metricName) {
throw new IllegalArgumentException("Metric name not found: " + metricName);
}

return metric.stream().collect(Collectors.averagingDouble(m -> m.getValue().doubleValue()));
}

/** Comparator based on {@code Metric}'s value field. */
private static final class MetricValueComparator implements Comparator<Metric>, Serializable {

private static final long serialVersionUID = 1L;

/** {@inheritDoc} */
@Override
public int compare(Metric o1, Metric o2) {
Number n1 = o1.getValue();
Number n2 = o2.getValue();
if (n1 instanceof Double || n1 instanceof Float) {
return Double.compare(n1.doubleValue(), n2.doubleValue());
}
return Long.compare(n1.longValue(), n2.longValue());
}
return metric.stream().collect(Collectors.averagingDouble(Metric::getValue));
}
}
Loading

0 comments on commit b887bc1

Please sign in to comment.