Skip to content

Commit

Permalink
[api] Enhancement features for LMSearch (#2642)
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng authored Jun 9, 2023
1 parent dd9a111 commit 9d7737c
Show file tree
Hide file tree
Showing 12 changed files with 113 additions and 12 deletions.
7 changes: 5 additions & 2 deletions api/src/main/java/ai/djl/ndarray/BaseNDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,14 @@ public String toString() {

/** {@inheritDoc} */
@Override
public synchronized void attachInternal(String resourceId, AutoCloseable resource) {
public synchronized void attachInternal(String resourceId, AutoCloseable... resources) {
if (capped.get()) {
throw new IllegalStateException("NDManager is capped for addition of resources.");
}
attachUncappedInternal(resourceId, resource);
for (int i = 0; i < resources.length; i++) {
attachUncappedInternal(
resources.length == 1 ? resourceId : resourceId + "_" + i, resources[i]);
}
}

/** {@inheritDoc} */
Expand Down
3 changes: 2 additions & 1 deletion api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -4206,7 +4206,8 @@ default NDArray argSort(int axis) {
* jshell&gt; NDArray array = manager.create(new float[] {0f, 1f, 2f, 3f}, new Shape(2, 2));
* jshell&gt; array.repeat(1, 2);
* ND: (6) cpu() float32
* [0., 0., 1., 1., 2., 2.]
* [[0., 0., 1., 1.],
* [2., 2., 3., 3.]]
* </pre>
*
* @param axis the axis to repeat
Expand Down
14 changes: 13 additions & 1 deletion api/src/main/java/ai/djl/ndarray/NDList.java
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,19 @@ public NDList addAll(NDList other) {
* @return a view of the portion of this NDList
*/
public NDList subNDList(int fromIndex) {
return new NDList(subList(fromIndex, size()));
return subNDList(fromIndex, size());
}

/**
* Returns a view of the portion of this NDList between the specified fromIndex, inclusive, and
* toIndex, exclusive.
*
* @param fromIndex the start index (inclusive)
* @param toIndex the end index (exclusive)
* @return a view of the portion of this NDList
*/
public NDList subNDList(int fromIndex, int toIndex) {
return new NDList(subList(fromIndex, toIndex));
}

/**
Expand Down
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/ndarray/NDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -1589,7 +1589,7 @@ default NDArray hanningWindow(long numPoints) {
* @param resourceId the unique resourceId
* @param resource the {@link AutoCloseable} resource to be attached
*/
void attachInternal(String resourceId, AutoCloseable resource);
void attachInternal(String resourceId, AutoCloseable... resource);

/**
* Attaches a resource to this {@code NDManager} circumventing any cap protection.
Expand Down
20 changes: 20 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDScope.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,26 @@ public static void unregister(NDArray array) {
queue.getLast().resources.remove(array);
}

/**
* Unregisters {@link NDArray} object from this scope.
*
* @param arrays the array of {@link NDArray} object
*/
public static void unregister(NDArray... arrays) {
for (NDArray array : arrays) {
unregister(array);
}
}

/**
* Unregisters {@link NDArray} object from this scope.
*
* @param ndlist the {@link NDList} object
*/
public static void unregister(NDList ndlist) {
ndlist.forEach(NDScope::unregister);
}

/** {@inheritDoc} */
@Override
public void close() {
Expand Down
10 changes: 10 additions & 0 deletions api/src/main/java/ai/djl/ndarray/index/NDIndex.java
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,16 @@ public NDIndex addBooleanIndex(NDArray index) {
return this;
}

/**
* Appends ellipse index in the current dimension.
*
* @return the updated {@link NDIndex}
*/
public NDIndex addEllipseDim() {
ellipsisIndex = indices.size();
return this;
}

/**
* Appends a new index to get all values in the dimension.
*
Expand Down
9 changes: 9 additions & 0 deletions api/src/main/java/ai/djl/ndarray/types/Shape.java
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,15 @@ public long get(int dimension) {
return shape[dimension];
}

/**
* Returns the last index.
*
* @return the last index
*/
public long getLastDimension() {
return shape[shape.length - 1];
}

/**
* Returns the layout type in the given dimension.
*
Expand Down
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/nn/AbstractBaseBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ public final NDList forward(
NDList inputs,
boolean training,
PairList<String, Object> params) {
NDManager paramsManager = parameterStore.getManager();
if (training && !isInitialized()) {
NDManager paramsManager = parameterStore.getManager();
initialize(paramsManager, DataType.FLOAT32, inputs.getShapes());
}
return forwardInternal(parameterStore, inputs, training, params);
Expand Down
3 changes: 2 additions & 1 deletion api/src/main/java/ai/djl/repository/zoo/Criteria.java
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ public ZooModel<I, O> loadModel()
}
}
throw new ModelNotFoundException(
"No matching model with specified Input/Output type found.", lastException);
"No model with the specified URI or the matching Input/Output type is found.",
lastException);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ public List<NDArray> getManagedArrays() {

/** {@inheritDoc} */
@Override
public void attachInternal(String resourceId, AutoCloseable resource) {}
public void attachInternal(String resourceId, AutoCloseable... resource) {}

/** {@inheritDoc} */
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public void testPassthrough() {
Assert.assertEquals(manager.getName(), "PassthroughNDManager");
Assert.assertTrue(manager.isOpen());
Assert.assertNotNull(manager.getParentManager());
manager.attachInternal(null, null);
manager.attachInternal(null, (AutoCloseable) null);
manager.attachUncappedInternal(null, null);
manager.tempAttachInternal(null, null, null);
manager.detachInternal(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public static NDList forward(PtSymbolBlock block, NDList inputs, boolean isTrain
* @param inputs the input {@link IValue}
* @return the result {@link IValue}
*/
public static IValue forward(PtSymbolBlock block, IValue... inputs) {
public static IValue forward(PtSymbolBlock block, IValue[] inputs) {
return runMethod(block, "forward", inputs);
}

Expand All @@ -79,9 +79,10 @@ public static IValue forward(PtSymbolBlock block, IValue... inputs) {
* @return the result {@link IValue}
*/
public static IValue runMethod(PtSymbolBlock block, String methodName, IValue... inputs) {
long[] handles = Arrays.stream(inputs).mapToLong(IValue::getHandle).toArray();
long[] iValueHandles = Arrays.stream(inputs).mapToLong(IValue::getHandle).toArray();
return new IValue(
PyTorchLibrary.LIB.moduleRunMethod(block.getHandle(), methodName, handles, false));
PyTorchLibrary.LIB.moduleRunMethod(
block.getHandle(), methodName, iValueHandles, false));
}

private static int addToMap(
Expand Down Expand Up @@ -146,4 +147,48 @@ static Pair<IValue[], String> getInputs(NDList ndList) {
}
return new Pair<>(ret, methodName);
}

/**
* Converts ndList to IValue.
*
* @param ndList the NDList to convert
* @param dims the shape of the output
* @return the result {@link IValue}
*/
public static IValue toTupleIValue(NDList ndList, long[] dims) {
return toTupleIValueRecur(ndList, dims, 0, 0).getKey();
}

/**
* Helper function.
*
* @param ndList the NDList to convert
* @param dims the shape of the output
* @param startCount the start index of the current recursion level
* @param level the recursion level
* @return the result
*/
private static Pair<IValue, Integer> toTupleIValueRecur(
NDList ndList, long[] dims, int startCount, int level) {
if (startCount > ndList.size()) {
throw new IllegalArgumentException("startCount illegal");
}
if (dims.length - 1 == level) {
long dim = dims[level];
List<PtNDArray> vector = new ArrayList<>();
for (int i = startCount; i < startCount + dim; i++) {
vector.add((PtNDArray) ndList.get(i));
}
IValue[] output = vector.stream().map(IValue::from).toArray(IValue[]::new);
return new Pair<>(IValue.tupleFrom(output), Math.toIntExact((startCount + dim)));
}

IValue[] output = new IValue[Math.toIntExact(dims[0])];
for (int j = 0; j < dims[level]; j++) {
Pair<IValue, Integer> p = toTupleIValueRecur(ndList, dims, startCount, level + 1);
startCount = p.getValue();
output[j] = p.getKey();
}
return new Pair<>(IValue.tupleFrom(output), startCount);
}
}

0 comments on commit 9d7737c

Please sign in to comment.