Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Glebashnik/fix tensor label mapping memory leak #32831

Draft
wants to merge 19 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions vespajlib/abi-spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,21 @@
],
"fields" : [ ]
},
"com.yahoo.tensor.Label" : {
"superClass" : "java.lang.Object",
"interfaces" : [ ],
"attributes" : [
"public",
"interface",
"abstract"
],
"methods" : [
"public abstract long asNumeric()",
"public abstract java.lang.String asString()",
"public abstract boolean isEqualTo(com.yahoo.tensor.Label)"
],
"fields" : [ ]
},
"com.yahoo.tensor.MappedTensor$Builder" : {
"superClass" : "java.lang.Object",
"interfaces" : [
Expand Down Expand Up @@ -1077,6 +1092,7 @@
],
"methods" : [
"public void <init>(int)",
"public com.yahoo.tensor.PartialAddress$Builder add(java.lang.String, com.yahoo.tensor.Label)",
"public com.yahoo.tensor.PartialAddress$Builder add(java.lang.String, long)",
"public com.yahoo.tensor.PartialAddress$Builder add(java.lang.String, java.lang.String)",
"public com.yahoo.tensor.PartialAddress build()"
Expand All @@ -1091,8 +1107,10 @@
],
"methods" : [
"public java.lang.String dimension(int)",
"public com.yahoo.tensor.Label objectLabel(java.lang.String)",
"public long numericLabel(java.lang.String)",
"public java.lang.String label(java.lang.String)",
"public com.yahoo.tensor.Label objectLabel(int)",
"public java.lang.String label(int)",
"public int size()",
"public com.yahoo.tensor.TensorAddress asAddress(com.yahoo.tensor.TensorType)",
Expand Down Expand Up @@ -1281,6 +1299,7 @@
"public void <init>(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.TensorAddress$Builder add(java.lang.String)",
"public com.yahoo.tensor.TensorAddress$Builder add(java.lang.String, java.lang.String)",
"public com.yahoo.tensor.TensorAddress$Builder add(java.lang.String, com.yahoo.tensor.Label)",
"public com.yahoo.tensor.TensorAddress$Builder add(java.lang.String, int)",
"public com.yahoo.tensor.TensorAddress$Builder add(java.lang.String, long)",
"public com.yahoo.tensor.TensorAddress$Builder copy()",
Expand Down Expand Up @@ -1317,6 +1336,7 @@
"public static varargs com.yahoo.tensor.TensorAddress of(long[])",
"public static varargs com.yahoo.tensor.TensorAddress of(int[])",
"public abstract int size()",
"public abstract com.yahoo.tensor.Label objectLabel(int)",
"public abstract java.lang.String label(int)",
"public abstract long numericLabel(int)",
"public abstract com.yahoo.tensor.TensorAddress withLabel(int, long)",
Expand Down Expand Up @@ -3827,14 +3847,14 @@
"public static java.lang.String toMessageString(java.lang.Throwable)",
"public static java.util.Optional findCause(java.lang.Throwable, java.lang.Class)",
"public static void uncheck(com.yahoo.yolean.Exceptions$RunnableThrowingIOException)",
"public static void uncheckInterrupted(com.yahoo.yolean.Exceptions$RunnableThrowingInterruptedException)",
"public static void uncheckInterruptedAndRestoreFlag(com.yahoo.yolean.Exceptions$RunnableThrowingInterruptedException)",
"public static varargs void uncheck(com.yahoo.yolean.Exceptions$RunnableThrowingIOException, java.lang.String, java.lang.String[])",
"public static void uncheckAndIgnore(com.yahoo.yolean.Exceptions$RunnableThrowingIOException, java.lang.Class)",
"public static java.util.function.Function uncheck(com.yahoo.yolean.Exceptions$FunctionThrowingIOException)",
"public static java.lang.Object uncheck(com.yahoo.yolean.Exceptions$SupplierThrowingIOException)",
"public static varargs java.lang.Object uncheck(com.yahoo.yolean.Exceptions$SupplierThrowingIOException, java.lang.String, java.lang.String[])",
"public static java.lang.Object uncheckAndIgnore(com.yahoo.yolean.Exceptions$SupplierThrowingIOException, java.lang.Class)",
"public static void uncheckInterrupted(com.yahoo.yolean.Exceptions$RunnableThrowingInterruptedException)",
"public static void uncheckInterruptedAndRestoreFlag(com.yahoo.yolean.Exceptions$RunnableThrowingInterruptedException)",
"public static java.lang.Object uncheckInterrupted(com.yahoo.yolean.Exceptions$SupplierThrowingInterruptedException)",
"public static java.lang.RuntimeException throwUnchecked(java.lang.Throwable)"
],
Expand Down
16 changes: 16 additions & 0 deletions vespajlib/src/main/java/com/yahoo/tensor/Label.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;

/**
* A label for a tensor dimension.
* It handles both mapped dimensions with string labels and indexed dimensions with numeric labels.
* For mapped dimensions, a negative numeric label is assigned by LabelCache.
* For indexed dimension, the index itself is used as a positive numeric label.
*
* @author glebashnik
*/
public interface Label {
long asNumeric();
String asString();
boolean isEqualTo(Label label);
}
65 changes: 43 additions & 22 deletions vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;

import com.yahoo.tensor.impl.Label;
import com.yahoo.tensor.impl.LabelCache;
import com.yahoo.tensor.impl.TensorAddressAny;

/**
* An address to a subset of a tensors' cells, specifying a label for some, but not necessarily all, of the tensors
Expand All @@ -18,7 +19,7 @@ public class PartialAddress {
// Two arrays which contains corresponding dimension:label pairs.
// The sizes of these are always equal.
private final String[] dimensionNames;
private final long[] labels;
private final Label[] labels;

private PartialAddress(Builder builder) {
this.dimensionNames = builder.dimensionNames;
Expand All @@ -31,31 +32,44 @@ public String dimension(int i) {
return dimensionNames[i];
}

/** Returns the numeric label of this dimension, or -1 if no label is specified for it */
public long numericLabel(String dimensionName) {

/** Returns the label object of this dimension, or -1 if no label is specified for it */
public Label objectLabel(String dimensionName) {
for (int i = 0; i < dimensionNames.length; i++)
if (dimensionNames[i].equals(dimensionName))
return labels[i];
return Tensor.invalidIndex;

return LabelCache.INVALID_INDEX_LABEL;
}

/** Returns the numeric label of this dimension, or -1 if no label is specified for it */
public long numericLabel(String dimensionName) {
return objectLabel(dimensionName).asNumeric();
}

/** Returns the label of this dimension, or null if no label is specified for it */
/** Returns the string label of this dimension, or null if no label is specified for it */
public String label(String dimensionName) {
for (int i = 0; i < dimensionNames.length; i++)
if (dimensionNames[i].equals(dimensionName))
return Label.fromNumber(labels[i]);
return null;
return objectLabel(dimensionName).asString();
}

/**
* Returns the label at position i
* Returns label object at position i
*
* @throws IllegalArgumentException if i is out of bounds
*/
public String label(int i) {
public Label objectLabel(int i) {
if (i >= size())
throw new IllegalArgumentException("No label at position " + i + " in " + this);
return Label.fromNumber(labels[i]);
return labels[i];
}

/**
* Returns string label at position i
*
* @throws IllegalArgumentException if i is out of bounds
*/
public String label(int i) {
return objectLabel(i).asString();
}

public int size() { return dimensionNames.length; }
Expand All @@ -65,14 +79,14 @@ public String label(int i) {
public TensorAddress asAddress(TensorType type) {
if (type.rank() != size())
throw new IllegalArgumentException(type + " has a different rank than " + this);
long[] numericLabels = new long[labels.length];
Label[] labels = new Label[this.labels.length];
for (int i = 0; i < type.dimensions().size(); i++) {
long label = numericLabel(type.dimensions().get(i).name());
if (label == Tensor.invalidIndex)
Label label = objectLabel(type.dimensions().get(i).name());
if (label.isEqualTo(LabelCache.INVALID_INDEX_LABEL))
throw new IllegalArgumentException(type + " dimension names does not match " + this);
numericLabels[i] = label;
labels[i] = label;
}
return TensorAddress.of(numericLabels);
return TensorAddressAny.ofUnsafe(labels);
}

@Override
Expand All @@ -88,24 +102,31 @@ public String toString() {
public static class Builder {

private String[] dimensionNames;
private long[] labels;
private Label[] labels;
private int index = 0;

public Builder(int size) {
dimensionNames = new String[size];
labels = new long[size];
labels = new Label[size];
}

public Builder add(String dimensionName, long label) {
public Builder add(String dimensionName, Label label) {
dimensionNames[index] = dimensionName;
labels[index] = label;
index++;
return this;
}

public Builder add(String dimensionName, long label) {
dimensionNames[index] = dimensionName;
labels[index] = LabelCache.GLOBAL.getOrCreateLabel(label);
index++;
return this;
}

public Builder add(String dimensionName, String label) {
dimensionNames[index] = dimensionName;
labels[index] = Label.toNumber(label);
labels[index] = LabelCache.GLOBAL.getOrCreateLabel(label);
index++;
return this;
}
Expand Down
4 changes: 2 additions & 2 deletions vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.yahoo.tensor.functions.Softmax;
import com.yahoo.tensor.functions.XwPlusB;
import com.yahoo.tensor.functions.Expand;
import com.yahoo.tensor.impl.Label;

import java.util.ArrayList;
import java.util.Iterator;
Expand Down Expand Up @@ -626,7 +625,8 @@ public CellBuilder label(String dimension, String label) {
public TensorType type() { return tensorBuilder.type(); }

public CellBuilder label(String dimension, long label) {
return label(dimension, Label.fromNumber(label));
addressBuilder.add(dimension, label);
return this;
}

public Builder value(double cellValue) {
Expand Down
Loading