diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
index 2f4f3e6409ed..ac3517b151f1 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
@@ -18,6 +18,13 @@ package org.apache.mxnet.javaapi
import collection.JavaConverters._
+/**
+ * Constructs a context that specifies the device and device type to be
+ * used by the engine.
+ *
+ * @param deviceTypeName {'cpu', 'gpu'} String representing the device type
+ * @param deviceId The device id of the device, needed for GPU
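+ *
+ * A minimal usage sketch (the device ids shown are illustrative):
+ * {{{
+ * val cpuCtx = new Context("cpu", 0)
+ * val gpuCtx = new Context("gpu", 0)
+ * }}}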
+ */
class Context(val context: org.apache.mxnet.Context) {
val deviceTypeid: Int = context.deviceTypeid
@@ -26,6 +33,11 @@ class Context(val context: org.apache.mxnet.Context) {
= this(new org.apache.mxnet.Context(deviceTypeName, deviceId))
def withScope[T](body: => T): T = context.withScope(body)
+
+ /**
+ * Return the device type of the current context.
+ * @return the device type, e.g. "cpu" or "gpu"
+ */
def deviceType: String = context.deviceType
override def toString: String = context.toString
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala
index 47b1c367c1c2..bf961b2bd529 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala
@@ -30,5 +30,13 @@ object DataDesc{
implicit def toDataDesc(dataDesc: DataDesc): org.apache.mxnet.DataDesc = dataDesc.dataDesc
+ /**
+ * Get the dimension that corresponds to the batch size.
+ * @param layout layout string. For example, "NCHW".
+ * @return An axis indicating the batch_size dimension. When data-parallelism is used,
+ * the data will be automatically split and concatenated along the batch_size dimension.
+ * Axis can be -1, which means the whole array will be copied
+ * for each data-parallelism device.
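+ *
+ * A minimal sketch (the layouts shown are illustrative):
+ * {{{
+ * DataDesc.getBatchAxis("NCHW") // 0: batch is the first axis
+ * DataDesc.getBatchAxis("TNC")  // 1: batch is the second axis
+ * }}}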
+ */
def getBatchAxis(layout: String): Int = org.apache.mxnet.DataDesc.getBatchAxis(Some(layout))
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
index 96119be84b91..d4e67f73408e 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
@@ -29,27 +29,64 @@ object NDArray extends NDArrayBase {
def waitall(): Unit = org.apache.mxnet.NDArray.waitall()
+ /**
+ * One-hot encode the indices into the matrix `out`.
+ * @param indices An NDArray containing indices of the categorical features.
+ * @param out The result holder of the encoding.
+ * @return Same as out.
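+ *
+ * A minimal sketch (the shapes and values shown are illustrative):
+ * {{{
+ * val ctx = new Context("cpu", 0)
+ * val indices = new NDArray(Array(0f, 2f), new Shape(Array(2)), ctx)
+ * val out = NDArray.zeros(ctx, Array(2, 3))
+ * NDArray.onehotEncode(indices, out) // rows of out become one-hot vectors for 0 and 2
+ * }}}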
+ */
def onehotEncode(indices: NDArray, out: NDArray): NDArray
= org.apache.mxnet.NDArray.onehotEncode(indices, out)
+ /**
+ * Create a new, uninitialized NDArray with the specified shape.
+ *
+ * @param shape shape of the NDArray.
+ * @param ctx The context of the NDArray.
+ * @param dtype The data type of the NDArray.
+ *
+ * @return The created NDArray.
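+ *
+ * A minimal sketch (the shape and context shown are illustrative):
+ * {{{
+ * val arr = NDArray.empty(new Context("cpu", 0), Array(2, 3))
+ * }}}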
+ */
def empty(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
= org.apache.mxnet.NDArray.empty(shape, ctx, dtype)
def empty(ctx: Context, shape: Array[Int]): NDArray
= org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)
def empty(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
= org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)
+
+ /**
+ * Create a new NDArray filled with 0, with the specified shape.
+ *
+ * @param shape shape of the NDArray.
+ * @param ctx The context of the NDArray.
+ * @param dtype The data type of the NDArray.
+ *
+ * @return The created NDArray.
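+ *
+ * A minimal sketch (the shape and context shown are illustrative):
+ * {{{
+ * val zs = NDArray.zeros(new Context("cpu", 0), Array(2, 3))
+ * }}}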
+ */
def zeros(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
= org.apache.mxnet.NDArray.zeros(shape, ctx, dtype)
def zeros(ctx: Context, shape: Array[Int]): NDArray
= org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)
def zeros(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
= org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)
+
+ /**
+ * Create a new NDArray filled with 1, with the specified shape.
+ * @param shape shape of the NDArray.
+ * @param ctx The context of the NDArray.
+ * @param dtype The data type of the NDArray.
+ * @return The created NDArray.
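+ *
+ * A minimal sketch (the shape and context shown are illustrative):
+ * {{{
+ * val os = NDArray.ones(new Context("cpu", 0), Array(2, 3))
+ * }}}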
+ */
def ones(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
= org.apache.mxnet.NDArray.ones(shape, ctx, dtype)
def ones(ctx: Context, shape: Array[Int]): NDArray
= org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)
def ones(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
= org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)
+
+ /**
+ * Create a new NDArray filled with the given value, with the specified shape.
+ * @param shape shape of the NDArray.
+ * @param value value to fill the NDArray with
+ * @param ctx The context of the NDArray
+ * @return The created NDArray.
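+ *
+ * A minimal sketch (the fill value shown is illustrative):
+ * {{{
+ * val arr = NDArray.full(new Shape(Array(2, 2)), 3.14f, new Context("cpu", 0))
+ * }}}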
+ */
def full(shape: Shape, value: Float, ctx: Context): NDArray
= org.apache.mxnet.NDArray.full(shape, value, ctx)
@@ -65,37 +102,102 @@ object NDArray extends NDArrayBase {
def minimum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
def minimum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
+
+ /**
+ * Returns the result of element-wise **equal to** (==) comparison operation with broadcasting.
+ * For each element in the input arrays, returns 1 (true) if the corresponding elements
+ * are equal, otherwise 0 (false).
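+ *
+ * A minimal sketch, illustrative for all the comparison operators below:
+ * {{{
+ * val ctx = new Context("cpu", 0)
+ * val a = NDArray.ones(ctx, Array(2, 2))
+ * val b = NDArray.zeros(ctx, Array(2, 2))
+ * NDArray.equal(a, b)  // all elements are 0 (false)
+ * NDArray.equal(a, 1f) // all elements are 1 (true)
+ * }}}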
+ */
def equal(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
def equal(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
+ /**
+ * Returns the result of element-wise **not equal to** (!=) comparison operation
+ * with broadcasting.
+ * For each element in the input arrays, returns 1 (true) if the corresponding elements
+ * differ, otherwise 0 (false).
+ */
def notEqual(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
def notEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
+ /**
+ * Returns the result of element-wise **greater than** (>) comparison operation
+ * with broadcasting.
+ * For each element in the input arrays, returns 1 (true) if the lhs element is greater
+ * than the rhs element, otherwise 0 (false).
+ */
def greater(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
def greater(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
+ /**
+ * Returns the result of element-wise **greater than or equal to** (>=) comparison
+ * operation with broadcasting.
+ * For each element in the input arrays, returns 1 (true) if the lhs element is greater
+ * than or equal to the rhs element, otherwise 0 (false).
+ */
def greaterEqual(lhs: NDArray, rhs: NDArray): NDArray
= org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
def greaterEqual(lhs: NDArray, rhs: Float): NDArray
= org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
+ /**
+ * Returns the result of element-wise **less than** (<) comparison operation
+ * with broadcasting.
+ * For each element in the input arrays, returns 1 (true) if the lhs element is less
+ * than the rhs element, otherwise 0 (false).
+ */
def lesser(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
def lesser(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
+ /**
+ * Returns the result of element-wise **less than or equal to** (<=) comparison
+ * operation with broadcasting.
+ * For each element in the input arrays, returns 1 (true) if the lhs element is less
+ * than or equal to the rhs element, otherwise 0 (false).
+ */
def lesserEqual(lhs: NDArray, rhs: NDArray): NDArray
= org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
def lesserEqual(lhs: NDArray, rhs: Float): NDArray
= org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
+ /**
+ * Create a new NDArray that copies its content from sourceArr.
+ * @param sourceArr Source data to create the NDArray from.
+ * @param shape shape of the NDArray
+ * @param ctx The context of the NDArray; defaults to the current default context.
+ * @return The created NDArray.
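+ *
+ * A minimal sketch (the values and shape shown are illustrative):
+ * {{{
+ * val data = java.util.Arrays.asList[java.lang.Float](1f, 2f, 3f, 4f)
+ * val arr = NDArray.array(data, new Shape(Array(2, 2)), new Context("cpu", 0))
+ * }}}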
+ */
def array(sourceArr: java.util.List[java.lang.Float], shape: Shape, ctx: Context = null): NDArray
= org.apache.mxnet.NDArray.array(
sourceArr.asScala.map(ele => Float.unbox(ele)).toArray, shape, ctx)
+ /**
+ * Returns evenly spaced values within a given interval.
+ * Values are generated within the half-open interval [`start`, `stop`). In other
+ * words, the interval includes `start` but excludes `stop`.
+ * @param start Start of interval.
+ * @param stop End of interval.
+ * @param step Spacing between values.
+ * @param repeat Number of times to repeat each element.
+ * @param ctx Device context.
+ * @param dType The data type of the `NDArray`.
+ * @return NDArray of evenly spaced values in the specified range.
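+ *
+ * A minimal sketch (the parameter values shown are illustrative):
+ * {{{
+ * // [0, 2, 4, 6, 8]
+ * val r = NDArray.arange(0f, 10f, 2f, 1, new Context("cpu", 0), DType.Float32)
+ * }}}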
+ */
def arange(start: Float, stop: Float, step: Float, repeat: Int,
ctx: Context, dType: DType.DType): NDArray =
org.apache.mxnet.NDArray.arange(start, Some(stop), step, repeat, ctx, dType)
}
+/**
+ * NDArray object in mxnet.
+ * NDArray is the basic ndarray/tensor-like data structure in mxnet.
+ *
+ * NOTE: NDArray is stored in native memory. Use NDArray in a try-with-resources block,
+ * or within a [[ResourceScope]], to have it automatically disposed. You can also
+ * explicitly control the lifetime of an NDArray by calling dispose manually. Failing
+ * to do so will leak native memory.
+ *
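+ * A minimal lifetime-management sketch (illustrative):
+ * {{{
+ * val ctx = new Context("cpu", 0)
+ * val a = NDArray.ones(ctx, Array(2, 2))
+ * try {
+ *   println(a.toArray.mkString(", "))
+ * } finally {
+ *   a.dispose() // release the native memory
+ * }
+ * }}}
+ *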
+ */
class NDArray(val nd : org.apache.mxnet.NDArray ) {
def this(arr : Array[Float], shape : Shape, ctx : Context) = {
@@ -108,28 +210,88 @@ class NDArray(val nd : org.apache.mxnet.NDArray ) {
def serialize() : Array[Byte] = nd.serialize()
+ /**
+ * Release the native memory.
+ * The NDArrays it depends on will NOT be disposed.
+ * The object shall never be used after it is disposed.
+ */
def dispose() : Unit = nd.dispose()
+
+ /**
+ * Dispose all NDArrays that were used to construct this array.
+ * e.g. (a * b + c).disposeDeps() will dispose a, b, c (including their deps) and a * b
+ * @return this array
+ */
def disposeDeps() : NDArray = nd.disposeDepsExcept()
// def disposeDepsExcept(arr : Array[NDArray]) : NDArray = nd.disposeDepsExcept()
+ /**
+ * Return a sliced NDArray that shares memory with the current one.
+ * NDArray only supports contiguous slicing on axis 0.
+ *
+ * @param start Starting index of slice.
+ * @param stop Finishing index of slice.
+ *
+ * @return a sliced NDArray that shares memory with the current one.
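+ *
+ * A minimal sketch (the shape and indices shown are illustrative):
+ * {{{
+ * val a = NDArray.ones(new Context("cpu", 0), Array(4, 2))
+ * val s = a.slice(1, 3) // shape (2, 2), shares memory with a
+ * }}}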
+ */
def slice(start : Int, stop : Int) : NDArray = nd.slice(start, stop)
+ /**
+ * Return a sliced NDArray at the ith position of axis 0.
+ * @param i the slice index
+ * @return a sliced NDArray that shares memory with the current one.
+ */
def slice (i : Int) : NDArray = nd.slice(i)
+ /**
+ * Return a sub NDArray that shares memory with the current one.
+ * The first axis is rolled up, so the result's shape differs from that of slice(i, i+1).
+ * @param idx index of the sub array.
+ */
def at(idx : Int) : NDArray = nd.at(idx)
def T : NDArray = nd.T
+ /**
+ * Get the data type of the current NDArray.
+ * @return the data type of the current ndarray
+ */
def dtype : DType = nd.dtype
+ /**
+ * Return a copy of the current array converted to the specified type.
+ * @param dtype Desired type of result array.
+ * @return A copy of array content.
+ */
def asType(dtype : DType) : NDArray = nd.asType(dtype)
+ /**
+ * Return a reshaped NDArray that shares memory with the current one.
+ * @param dims New shape.
+ *
+ * @return a reshaped NDArray that shares memory with the current one.
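+ *
+ * A minimal sketch (the shapes shown are illustrative):
+ * {{{
+ * val a = NDArray.ones(new Context("cpu", 0), Array(2, 3))
+ * val b = a.reshape(Array(3, 2)) // same 6 elements, viewed as (3, 2)
+ * }}}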
+ */
def reshape(dims : Array[Int]) : NDArray = nd.reshape(dims)
+ /**
+ * Block until all pending write operations on the current NDArray are finished.
+ * This function returns once all pending writes to the current
+ * NDArray have finished. There may still be pending reads going on when the
+ * function returns.
+ */
def waitToRead(): Unit = nd.waitToRead()
+ /**
+ * Get context of current NDArray.
+ * @return The context of current NDArray.
+ */
def context : Context = nd.context
+ /**
+ * Set the values of the NDArray
+ * @param value Value to set
+ * @return Current NDArray
+ */
def set(value : Float) : NDArray = nd.set(value)
def set(other : NDArray) : NDArray = nd.set(other)
def set(other : Array[Float]) : NDArray = nd.set(other)
@@ -167,20 +329,57 @@ class NDArray(val nd : org.apache.mxnet.NDArray ) {
def lesserEqual(other : NDArray) : NDArray = this.nd <= other
def lesserEqual(other : Float) : NDArray = this.nd <= other
+ /**
+ * Return a copied flat Java array of the current array (row-major).
+ * @return A copy of array content.
+ */
def toArray : Array[Float] = nd.toArray
+ /**
+ * Return a CPU scalar (float) of the current ndarray.
+ * The ndarray must have shape (1,).
+ *
+ * @return The scalar representation of the ndarray.
+ */
def toScalar : Float = nd.toScalar
+ /**
+ * Copy the content of current array to other.
+ *
+ * @param other Target NDArray we want to copy data to.
+ * @return The copy target NDArray
+ */
def copyTo(other : NDArray) : NDArray = nd.copyTo(other)
+ /**
+ * Copy the content of current array to a new NDArray in the context.
+ *
+ * @param ctx Target context we want to copy data to.
+ * @return The copy target NDArray
+ */
def copyTo(ctx : Context) : NDArray = nd.copyTo(ctx)
+ /**
+ * Clone the current array
+ * @return the copied NDArray in the same context
+ */
def copy() : NDArray = copyTo(this.context)
+ /**
+ * Get the shape of the current NDArray.
+ * @return a Shape object representing the shape of the current ndarray
+ */
def shape : Shape = nd.shape
+
def size : Int = shape.product
+ /**
+ * Return an `NDArray` that lives in the target context. If the array
+ * is already in that context, the array itself is returned. Otherwise, a copy is made.
+ * @param context The target context we want the return value to live in.
+ * @return A copy of the array, or the array itself, living in the target context.
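+ *
+ * A minimal sketch (the contexts shown are illustrative):
+ * {{{
+ * val a = NDArray.ones(new Context("cpu", 0), Array(2, 2))
+ * val b = a.asInContext(new Context("cpu", 0)) // already on CPU, so no copy is made
+ * }}}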
+ */
def asInContext(context: Context): NDArray = nd.asInContext(context)
override def equals(obj: Any): Boolean = nd.equals(obj)
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala
index 6cd3df6b896b..f48375ffe4a7 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala
@@ -26,7 +26,21 @@ import org.apache.mxnet.javaapi.{Context, DataDesc, NDArray}
import scala.collection.JavaConverters
import scala.collection.JavaConverters._
-
+/**
+ * The ObjectDetector class helps to run ObjectDetection tasks where the goal
+ * is to find bounding boxes and corresponding labels for objects in an image.
+ *
+ * @param modelPathPrefix Path prefix from where to load the model artifacts.
+ * These include the symbol, parameters, and synset.txt.
+ * Example: file://model-dir/ssd_resnet50_512 (containing
+ * ssd_resnet50_512-symbol.json, ssd_resnet50_512-0000.params,
+ * and synset.txt)
+ * @param inputDescriptors Descriptors defining the input node names, shape,
+ * layout and type parameters
+ * @param contexts Device contexts on which you want to run inference.
+ * Defaults to CPU.
+ * @param epoch Model epoch to load; defaults to 0
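+ *
+ * A minimal usage sketch. The model path, input shape, and image file below are
+ * illustrative assumptions, not part of this patch:
+ * {{{
+ * val ctx = java.util.Arrays.asList(new Context("cpu", 0))
+ * val inputDesc = java.util.Arrays.asList(
+ *   new DataDesc("data", new Shape(Array(1, 3, 512, 512)), DType.Float32, "NCHW"))
+ * val detector = new ObjectDetector("file://model-dir/ssd_resnet50_512", inputDesc, ctx, 0)
+ * val img = javax.imageio.ImageIO.read(new java.io.File("dog.jpg"))
+ * val output = detector.imageObjectDetect(img, 3) // top-3 detections per batch element
+ * }}}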
+ */
class ObjectDetector(val objDetector: org.apache.mxnet.infer.ObjectDetector){
def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc], contexts:
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
index 26ccd06cf466..3e0fcb7b507a 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
@@ -22,6 +22,23 @@ import org.apache.mxnet.javaapi.{Context, DataDesc, NDArray}
import scala.collection.JavaConverters
import scala.collection.JavaConverters._
+/**
+ * Implementation of prediction routines.
+ *
+ * @param modelPathPrefix Path prefix from where to load the model artifacts.
+ * These include the symbol, parameters, and synset.txt
+ * Example: file://model-dir/resnet-152 (containing
+ * resnet-152-symbol.json, resnet-152-0000.params, and synset.txt).
+ * @param inputDescriptors Descriptors defining the input node names, shape,
+ * layout and type parameters
+ * Note: If the input Descriptors is missing batchSize
+ * ('N' in layout), a batchSize of 1 is assumed for the model.
+ * @param contexts Device contexts on which you want to run inference; defaults to CPU
+ * @param epoch Model epoch to load; defaults to 0
+ */
+
+// JavaDoc description of class to be updated in https://issues.apache.org/jira/browse/MXNET-1178
class Predictor(val predictor: org.apache.mxnet.infer.Predictor){
def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc], contexts:
java.util.List[Context], epoch: Int)