FIX: solve problems before merge
cspchen committed Aug 31, 2021
1 parent bfc1f5d commit 720a14d
Showing 15 changed files with 62 additions and 289 deletions.
22 changes: 2 additions & 20 deletions ci/docker/runtime_functions.sh
@@ -321,11 +321,8 @@ build_ubuntu_gpu_and_test() {
}

java_package_integration_test() {
# install gradle
add-apt-repository ppa:cwchien/gradle
apt-get update -y
apt-get install gradle -y
# build java prokect
# make sure you are using java 11
# build java project
cd /work/mxnet/java-package
./gradlew build -x javadoc
# generate native library
@@ -1429,21 +1426,6 @@ test_artifact_repository() {
popd
}

integration_test() {
# install gradle
add-apt-repository ppa:cwchien/gradle
apt-get update -y
apt-get install gradle -y
# build java prokect
cd /work/mxnet/java-package
./gradle build -x javadoc
# generate native library
./gradlew :native:buildLocalLibraryJarDefault
./gradlew :native:mkl-linuxJar
# run integration
./gradlew :integration:run
}

##############################################################
# MAIN
#
44 changes: 0 additions & 44 deletions docker/dev/docker_run.sh

This file was deleted.

17 changes: 17 additions & 0 deletions java-package/gradle/wrapper/gradle-wrapper.properties
@@ -1,3 +1,20 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-7.0-bin.zip
13 changes: 7 additions & 6 deletions java-package/gradlew
@@ -1,13 +1,14 @@
#!/usr/bin/env sh

#
# Copyright 2015 the original author or authors.
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
3 changes: 0 additions & 3 deletions java-package/gradlew.bat
@@ -1,6 +1,3 @@
@rem
@rem Copyright 2015 the original author or authors.
@rem
@rem Licensed under the Apache License, Version 2.0 (the "License");
@rem you may not use this file except in compliance with the License.
@rem You may obtain a copy of the License at
2 changes: 0 additions & 2 deletions java-package/integration/build.gradle
@@ -48,9 +48,7 @@ run {
systemProperties System.getProperties()
systemProperties.remove("user.dir")
systemProperty("file.encoding", "UTF-8")
// systemProperty "java.library.path", "/Users/cspchen/Work/incubator-mxnet/build"
jvmArgs "-Xverify:none"
args("-p=org.apache.mxnet.integration.tests.", "-m=modelLoadAndPredictTest")
}

checkstyleMain {
@@ -39,8 +39,6 @@ public class ModelTest {
public void modelLoadAndPredictTest() {
try (MxResource base = BaseMxResource.getSystemMxResource()) {
Model model = Model.loadModel(Item.MLP);
// Model model = Model.loadModel("test",
// Paths.get("/xxx/xxx/mxnet.java_package/cache/repo/test-models/mlp.tar.gz/mlp/"));
Predictor<NDList, NDList> predictor = model.newPredictor();
NDArray input = NDArray.create(base, new Shape(1, 28, 28)).ones();
NDList inputs = new NDList();
@@ -19,6 +19,7 @@

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -144,33 +145,39 @@ public void createNdArray() {
throw e;
}
BaseMxResource base = BaseMxResource.getSystemMxResource();
Assert.assertEquals(base.getSubResource().size(), 0);
int countNotReleased = 0;
for (MxResource mxResource : base.getSubResource().values()) {
if (!mxResource.getClosed()) {
++countNotReleased;
}
}
Assert.assertEquals(countNotReleased, 0);
} catch (ClassCastException e) {
logger.error(e.getMessage());
throw e;
}
}

// @Test
// public void loadNdArray() {
//
// try (BaseMxResource base = BaseMxResource.getSystemMxResource()) {
// NDList mxNDArray =
// JnaUtils.loadNdArray(
// base,
// Paths.get(
//
// "/Users/cspchen/Downloads/mxnet_resnet18/resnet18_v1-0000.params"),
// Device.defaultIfNull(null));
// logger.info(mxNDArray.toString());
// logger.info(
// String.format(
// "The amount of sub resources managed by BaseMxResource: %s",
// base.getSubResource().size()));
// }
// logger.info(
// String.format(
// "The amount of sub resources managed by BaseMxResource: %s",
// BaseMxResource.getSystemMxResource().getSubResource().size()));
// }
@Test
public void loadNdArray() throws IOException {
try (BaseMxResource base = BaseMxResource.getSystemMxResource()) {
Path modelPath = Repository.initRepository(Item.MLP);
Path paramsPath = modelPath.resolve("mlp-0000.params");
NDList mxNDArray =
JnaUtils.loadNdArray(
base, Paths.get(paramsPath.toUri()), Device.defaultIfNull(null));
logger.info(mxNDArray.toString());
logger.info(
String.format(
"The amount of sub resources managed by BaseMxResource: %s",
base.getSubResource().size()));
} catch (IOException e) {
logger.error(e.getMessage());
throw e;
}
logger.info(
String.format(
"The amount of sub resources managed by BaseMxResource: %s",
BaseMxResource.getSystemMxResource().getSubResource().size()));
}
}
@@ -33,7 +33,6 @@
import java.util.stream.IntStream;
import org.apache.mxnet.engine.BaseMxResource;
import org.apache.mxnet.engine.Device;
import org.apache.mxnet.engine.GradReq;
import org.apache.mxnet.engine.MxResource;
import org.apache.mxnet.engine.OpParams;
import org.apache.mxnet.jna.JnaUtils;
@@ -696,26 +695,6 @@ public NDArray toType(DataType dataType, boolean copy) {
return duplicate(getShape(), dataType, getDevice(), getName());
}

/**
* Attaches a gradient {@code NDArray} to this {@code NDArray} and marks it. It is related to
* training so will not be used here.
*
* @param requiresGrad if {@code NDArray} requires gradient or not
*/
public void setRequiresGradient(boolean requiresGrad) {
if ((requiresGrad && hasGradient()) || (!requiresGrad && !hasGradient())) {
return;
}
NDArray grad = hasGradient() ? getGradient() : createGradient(getSparseFormat());
// DJL go with write as only MXNet support GradReq
int gradReqValue = requiresGrad ? GradReq.WRITE.getValue() : GradReq.NULL.getValue();
IntBuffer gradReqBuffer = IntBuffer.allocate(1);
gradReqBuffer.put(0, gradReqValue);
JnaUtils.autogradMarkVariables(1, getHandle(), gradReqBuffer, grad.getHandle());
hasGradient = requiresGrad;
grad.close();
}

/**
* Creates an instance of {@link NDArray} with specified {@link Shape} filled with zeros.
*
@@ -32,8 +32,7 @@ public enum SparseFormat {
// the dense format is accelerated by MKLDNN by default
DENSE("default", 0),
ROW_SPARSE("row_sparse", 1),
CSR("csr", 2),
COO("coo", 3);
CSR("csr", 2);

private String type;
private int value;

This file was deleted.

@@ -53,15 +53,13 @@ public class Parameter extends MxResource {
private Shape shape;
private Type type;
private NDArray array;
private boolean requiresGrad;

Parameter(Builder builder) {
this.id = UUID.randomUUID().toString();
this.name = builder.name;
this.shape = builder.shape;
this.type = builder.type;
this.array = builder.array;
this.requiresGrad = builder.requiresGrad;
}

/**
@@ -129,15 +127,6 @@ public NDArray getArray() {
return array;
}

/**
* Returns whether this parameter needs gradients to be computed.
*
* @return whether this parameter needs gradients to be computed
*/
public boolean requiresGradient() {
return requiresGrad;
}

/**
* Checks if this {@code Parameter} is initialized.
*
@@ -156,9 +145,6 @@ public boolean isInitialized() {
*/
public void initialize(MxResource parent, DataType dataType, Device device) {
Objects.requireNonNull(shape, "No parameter shape has been set");
if (requiresGradient()) {
array.setRequiresGradient(true);
}
}

/**
@@ -258,7 +244,6 @@ public static final class Builder {
Shape shape;
Type type;
NDArray array;
boolean requiresGrad = true;

/**
* Sets the name of the {@code Parameter}.
@@ -304,17 +289,6 @@ public Builder optArray(NDArray array) {
return this;
}

/**
* Sets if the {@code Parameter} requires gradient.
*
* @param requiresGrad if the {@code Parameter} requires gradient
* @return this {@code Parameter}
*/
public Builder optRequiresGrad(boolean requiresGrad) {
this.requiresGrad = requiresGrad;
return this;
}

/**
* Builds a {@code Parameter} instance.
*
@@ -530,16 +530,9 @@ private void initBlock() {
String[] allNames = symbol.getAllNames();
mxNetParams = new ArrayList<>(allNames.length);

Set<String> auxNameSet = new HashSet<>(Arrays.asList(symbol.getAuxNames()));
for (String name : allNames) {
Parameter.Type type = inferType(name);
boolean requireGrad = !auxNameSet.contains(name);
mxNetParams.add(
Parameter.builder()
.setName(name)
.setType(type)
.optRequiresGrad(requireGrad)
.build());
mxNetParams.add(Parameter.builder().setName(name).setType(type).build());
}
first = true;
}
