Skip to content

Commit

Permalink
Fix, add tests, and pass CI
Browse files Browse the repository at this point in the history
  • Loading branch information
zachgk committed Sep 28, 2023
1 parent 8522be0 commit 61597e5
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 57 deletions.
15 changes: 15 additions & 0 deletions api/src/main/java/ai/djl/nn/Block.java
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,21 @@ default void freezeParameters(boolean freeze) {
}
}

/**
* Freezes or unfreezes all parameters inside the block that pass the predicate.
*
* @param freeze true to mark as frozen rather than unfrozen
* @param pred true tests if the parameter should be updated
* @see Parameter#freeze(boolean)
*/
default void freezeParameters(boolean freeze, Predicate<Parameter> pred) {
for (Parameter parameter : getParameters().values()) {
if (pred.test(parameter)) {
parameter.freeze(freeze);
}
}
}

/**
* Validates that actual layout matches the expected layout.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import ai.djl.Model;
import ai.djl.ndarray.types.DataType;
import ai.djl.nn.Parameter;
import ai.djl.nn.Parameter.Type;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
Expand Down Expand Up @@ -189,7 +190,9 @@ public Trainer newTrainer(TrainingConfig trainingConfig) {
}
if (wasLoaded) {
// Unfreeze parameters if training directly
block.freezeParameters(false);
block.freezeParameters(
false,
p -> p.getType() != Type.RUNNING_MEAN && p.getType() != Type.RUNNING_VAR);
}
for (Pair<Initializer, Predicate<Parameter>> pair : initializer) {
if (pair.getKey() != null && pair.getValue() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package ai.djl.pytorch.engine;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDUtils;
Expand All @@ -25,6 +26,7 @@
import ai.djl.pytorch.jni.JniUtils;

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

/** {@code PtNDArrayEx} is the PyTorch implementation of the {@link NDArrayEx}. */
Expand Down Expand Up @@ -763,11 +765,9 @@ public NDList multiBoxDetection(
int nmsTopK) {
assert (inputs.size() == 3);

Shape ashape = inputs.get(2).getShape();

NDArray clsProb = inputs.get(0);
NDArray locPred = inputs.get(1);
NDArray anchors = inputs.get(2).reshape(new Shape(ashape.get(1), 4));
NDArray anchors = inputs.get(2).reshape(new Shape(-1, 4));

NDManager ndManager = array.getManager();

Expand All @@ -781,11 +781,12 @@ public NDList multiBoxDetection(
final float[] pAnchor = anchors.toFloatArray();

// [id, prob, xmin, ymin, xmax, ymax]
float[][] outputs = new float[numAnchors][6];
// TODO Move to NDArray-based implementation
NDList batchOutputs = new NDList();
for (int nbatch = 0; nbatch < numBatches; ++nbatch) {
final float[] pClsProb =
clsProb.get((long) nbatch * numClasses * numAnchors).toFloatArray();
final float[] pLocPred = locPred.get((long) nbatch * numAnchors * 4).toFloatArray();
float[][] outputs = new float[numAnchors][6];
final float[] pClsProb = clsProb.get(nbatch).toFloatArray();
final float[] pLocPred = locPred.get(nbatch).toFloatArray();

for (int i = 0; i < numAnchors; ++i) {
// find the predicted class id and probability
Expand Down Expand Up @@ -832,50 +833,15 @@ public NDList multiBoxDetection(
outputs[i][5] = outRowLast4[3];
}

int validCount = 0;
for (int i = 0; i < numAnchors; ++i) {
int offset1 = validCount;
if (outputs[i][0] >= 0) {
outputs[offset1][0] = outputs[i][0];
outputs[offset1][1] = outputs[i][1];
outputs[offset1][2] = outputs[i][2];
outputs[offset1][3] = outputs[i][3];
outputs[offset1][4] = outputs[i][4];
outputs[offset1][5] = outputs[i][5];
++validCount;
}
}

if (validCount < 1) continue;

float[][] sorter;
sorter = new float[validCount][2];
for (int i = 0; i < validCount; ++i) {
sorter[i][0] = outputs[i][1];
sorter[i][1] = i;
}
Arrays.sort(sorter, (a, b) -> Double.compare(a[0], b[0]) * -1);

// re-order output
float[][] ptemp = new float[outputs.length][6];
for (int i = 0; i < outputs.length; i++) {
for (int j = 0; j < 6; j++) {
ptemp[i][j] = outputs[(int) sorter[i][1]][j];
}
}
int nkeep = sorter.length;

for (int i = 0; i < nkeep; ++i) {
for (int j = 0; j < 6; ++j) {
outputs[i][j] = ptemp[i][j];
}
}
outputs =
Arrays.stream(outputs)
.filter(o -> o[0] >= 0)
.sorted(Comparator.comparing(o -> -o[1]))
.toArray(float[][]::new);

// apply nms
for (int i = 0; i < nkeep; ++i) {
if (outputs[i][0] < 0) continue; // skip eliminated
for (int j = i + 1; j < nkeep; ++j) {
if (outputs[j][0] < 0) continue; // skip eliminated
for (int i = 0; i < outputs.length; ++i) {
for (int j = i + 1; j < outputs.length; ++j) {
if (outputs[i][0] == outputs[j][0]) {
float[] outputsIRow4 = new float[4];
float[] outputsJRow4 = new float[4];
Expand All @@ -894,9 +860,10 @@ public NDList multiBoxDetection(
}
}
}
batchOutputs.add(ndManager.create(outputs));
} // end iter batch

NDArray pOutNDArray = ndManager.create(outputs).reshape(1, 4, 6);
NDArray pOutNDArray = NDArrays.stack(batchOutputs);
NDList resultNDList = new NDList();
resultNDList.add(pOutNDArray);
assert (resultNDList.size() == 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ public class TrainPikachuTest {

@Test
public void testDetection() throws IOException, MalformedModelException, TranslateException {
TestRequirements.engine("MXNet");
TestRequirements.nightly();

String[] args;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import ai.djl.nn.LambdaBlock;
import ai.djl.nn.SequentialBlock;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
Expand Down Expand Up @@ -123,10 +124,8 @@ private TrainingConfig setupTrainingConfig() {
}

private ZooModel<Image, DetectedObjects> getModel() throws IOException, ModelException {
// SSD-pikachu model only available in MXNet
// TODO: Add PyTorch model to model zoo
TestUtils.requiresEngine("MXNet");

TestUtils.requiresEngine(
ModelZoo.getModelZoo("ai.djl.zoo").getSupportedEngines().toArray(String[]::new));
Criteria<Image, DetectedObjects> criteria =
Criteria.builder()
.optApplication(Application.CV.OBJECT_DETECTION)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ public String getGroupId() {
public Set<String> getSupportedEngines() {
Set<String> set = new HashSet<>();
set.add("MXNet");
set.add("PyTorch");
// TODO Currently WIP in supporting these two engines in the basic model zoo
// set.add("PyTorch");
// set.add("TensorFlow");
return set;
}
Expand Down

0 comments on commit 61597e5

Please sign in to comment.