First pass at adding JavaDocs for new Java API classes (#12963)
* First pass at adding JavaDocs for new Java API classes

* Fix a scalastyle issue

* Updating JavaDoc based on feedback
andrewfayres authored and nswamy committed Oct 26, 2018
1 parent 5aaa729 commit 743301c
Showing 5 changed files with 251 additions and 1 deletion.
@@ -18,6 +18,13 @@ package org.apache.mxnet.javaapi

import collection.JavaConverters._

/**
* Constructs a context specifying the device and device type to be used by
* the engine.
*
* @param deviceTypeName String representing the device type: 'cpu' or 'gpu'
* @param deviceId The id of the device; required for GPU
*/
class Context(val context: org.apache.mxnet.Context) {

val deviceTypeid: Int = context.deviceTypeid
@@ -26,6 +33,11 @@ class Context(val context: org.apache.mxnet.Context) {
= this(new org.apache.mxnet.Context(deviceTypeName, deviceId))

def withScope[T](body: => T): T = context.withScope(body)

/**
* Returns the device type of the current context.
* @return device_type
*/
def deviceType: String = context.deviceType

override def toString: String = context.toString
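
For illustration, a minimal usage sketch of this wrapper (assuming the org.apache.mxnet.javaapi imports; device id 0 is arbitrary):

    val ctx = new Context("cpu", 0)
    println(ctx.deviceType) // prints "cpu"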
@@ -30,5 +30,13 @@ object DataDesc{

implicit def toDataDesc(dataDesc: DataDesc): org.apache.mxnet.DataDesc = dataDesc.dataDesc

/**
* Gets the dimension that corresponds to the batch size.
* @param layout layout string, for example "NCHW"
* @return The axis of the batch_size dimension. When data parallelism is used,
* the data will automatically be split and concatenated along the
* batch_size dimension. An axis of -1 means the whole array will be
* copied to each data-parallelism device.
*/
def getBatchAxis(layout: String): Int = org.apache.mxnet.DataDesc.getBatchAxis(Some(layout))
}
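
A quick sketch of how getBatchAxis reads a layout string (the "NCHW" layout is illustrative):

    val axis = DataDesc.getBatchAxis("NCHW") // 0, since 'N' (batch) is the first dimension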
@@ -29,27 +29,64 @@ object NDArray extends NDArrayBase {

def waitall(): Unit = org.apache.mxnet.NDArray.waitall()

/**
* One-hot encodes indices into the matrix out.
* @param indices An NDArray containing indices of the categorical features.
* @param out The result holder of the encoding.
* @return Same as out.
*/
def onehotEncode(indices: NDArray, out: NDArray): NDArray
= org.apache.mxnet.NDArray.onehotEncode(indices, out)
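
A minimal sketch of one-hot encoding two category indices into a 2x3 output (shapes and values chosen for illustration):

    import collection.JavaConverters._
    val ctx = new Context("cpu", 0)
    val indices = NDArray.array(List(0f, 2f).map(Float.box).asJava, new Shape(Array(2)), ctx)
    val out = NDArray.zeros(ctx, Array(2, 3))
    NDArray.onehotEncode(indices, out) // rows become [1,0,0] and [0,0,1]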

/**
* Creates a new, uninitialized NDArray with the specified shape.
*
* @param shape Shape of the NDArray.
* @param ctx The context of the NDArray.
* @param dtype The data type of the NDArray.
*
* @return The created NDArray.
*/
def empty(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
= org.apache.mxnet.NDArray.empty(shape, ctx, dtype)
def empty(ctx: Context, shape: Array[Int]): NDArray
= org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)
def empty(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
= org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)
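
For example, a sketch (shape and dtype are arbitrary; assumes DType.Float32 is available):

    val ctx = new Context("cpu", 0)
    val a = NDArray.empty(new Shape(Array(2, 3)), ctx, DType.Float32) // contents are uninitialized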

/**
* Creates a new NDArray of the specified shape, filled with 0.
*
* @param shape Shape of the NDArray.
* @param ctx The context of the NDArray.
* @param dtype The data type of the NDArray.
*
* @return The created NDArray.
*/
def zeros(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
= org.apache.mxnet.NDArray.zeros(shape, ctx, dtype)
def zeros(ctx: Context, shape: Array[Int]): NDArray
= org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)
def zeros(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
= org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)

/**
* Creates a new NDArray of the specified shape, filled with 1.
* @param shape Shape of the NDArray.
* @param ctx The context of the NDArray.
* @param dtype The data type of the NDArray.
* @return The created NDArray.
*/
def ones(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
= org.apache.mxnet.NDArray.ones(shape, ctx, dtype)
def ones(ctx: Context, shape: Array[Int]): NDArray
= org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)
def ones(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
= org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)

/**
* Creates a new NDArray of the specified shape, filled with the given value.
* @param shape Shape of the NDArray.
* @param value The value to fill the NDArray with.
* @param ctx The context of the NDArray.
* @return The created NDArray.
*/
def full(shape: Shape, value: Float, ctx: Context): NDArray
= org.apache.mxnet.NDArray.full(shape, value, ctx)
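
Taken together, a small sketch of the factory methods above (shapes and fill value are arbitrary):

    val ctx = new Context("cpu", 0)
    val z = NDArray.zeros(ctx, Array(2, 2)) // all 0
    val o = NDArray.ones(ctx, Array(2, 2)) // all 1
    val f = NDArray.full(new Shape(Array(2, 2)), 7f, ctx) // all 7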

@@ -65,37 +102,102 @@ object NDArray extends NDArrayBase {
def minimum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
def minimum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)


/**
* Returns the result of element-wise **equal to** (==) comparison operation with broadcasting.
* For each element in the input arrays, returns 1 (true) if the corresponding elements
* are the same, otherwise 0 (false).
*/
def equal(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
def equal(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)

/**
* Returns the result of element-wise **not equal to** (!=) comparison operation
* with broadcasting.
* For each element in the input arrays, returns 1 (true) if the corresponding elements
* differ, otherwise 0 (false).
*/
def notEqual(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
def notEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)

/**
* Returns the result of element-wise **greater than** (>) comparison operation
* with broadcasting.
* For each element in the input arrays, returns 1 (true) if the lhs element is
* greater than the rhs element, otherwise 0 (false).
*/
def greater(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
def greater(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)

/**
* Returns the result of element-wise **greater than or equal to** (>=) comparison
* operation with broadcasting.
* For each element in the input arrays, returns 1 (true) if the lhs element is
* greater than or equal to the rhs element, otherwise 0 (false).
*/
def greaterEqual(lhs: NDArray, rhs: NDArray): NDArray
= org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
def greaterEqual(lhs: NDArray, rhs: Float): NDArray
= org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)

/**
* Returns the result of element-wise **lesser than** (<) comparison operation
* with broadcasting.
* For each element in the input arrays, returns 1 (true) if the lhs element is
* less than the rhs element, otherwise 0 (false).
*/
def lesser(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
def lesser(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)

/**
* Returns the result of element-wise **lesser than or equal to** (<=) comparison
* operation with broadcasting.
* For each element in the input arrays, returns 1 (true) if the lhs element is
* less than or equal to the rhs element, otherwise 0 (false).
*/
def lesserEqual(lhs: NDArray, rhs: NDArray): NDArray
= org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
def lesserEqual(lhs: NDArray, rhs: Float): NDArray
= org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
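
A combined sketch of the comparison operators above (values chosen for illustration):

    val ctx = new Context("cpu", 0)
    val a = NDArray.ones(ctx, Array(3))
    NDArray.equal(a, 1f).toArray // Array(1f, 1f, 1f)
    NDArray.greater(a, NDArray.zeros(ctx, Array(3))).toArray // Array(1f, 1f, 1f)
    NDArray.lesserEqual(a, 0f).toArray // Array(0f, 0f, 0f)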

/**
* Creates a new NDArray that copies its content from sourceArr.
* @param sourceArr Source data to create the NDArray from.
* @param shape Shape of the NDArray.
* @param ctx The context of the NDArray; defaults to the current default context.
* @return The created NDArray.
*/
def array(sourceArr: java.util.List[java.lang.Float], shape: Shape, ctx: Context = null): NDArray
= org.apache.mxnet.NDArray.array(
sourceArr.asScala.map(ele => Float.unbox(ele)).toArray, shape, ctx)
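
For example, a sketch using JavaConverters to build the boxed list:

    import collection.JavaConverters._
    val data = List(1f, 2f, 3f, 4f).map(Float.box).asJava
    val nd = NDArray.array(data, new Shape(Array(2, 2)), new Context("cpu", 0))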

/**
* Returns evenly spaced values within a given interval.
* Values are generated within the half-open interval [`start`, `stop`). In other
* words, the interval includes `start` but excludes `stop`.
* @param start Start of interval.
* @param stop End of interval.
* @param step Spacing between values.
* @param repeat Number of times to repeat each element.
* @param ctx Device context.
* @param dType The data type of the `NDArray`.
* @return NDArray of evenly spaced values in the specified range.
*/
def arange(start: Float, stop: Float, step: Float, repeat: Int,
ctx: Context, dType: DType.DType): NDArray =
org.apache.mxnet.NDArray.arange(start, Some(stop), step, repeat, ctx, dType)
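
For example (values chosen for illustration; assumes DType.Float32 is available):

    val ctx = new Context("cpu", 0)
    val r = NDArray.arange(0f, 10f, 2f, 1, ctx, DType.Float32)
    r.toArray // Array(0f, 2f, 4f, 6f, 8f)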
}

/**
* NDArray object in mxnet.
* NDArray is the basic ndarray/tensor-like data structure in mxnet. <br />
* <b>
* NOTE: NDArray is stored in native memory. Use an NDArray in a try-with-resources
* construct, or within a [[ResourceScope]], to have it disposed automatically. You can
* also control the lifetime of an NDArray explicitly by calling dispose manually;
* failing to do so will leak native memory.
* </b>
*/
class NDArray(val nd : org.apache.mxnet.NDArray ) {

def this(arr : Array[Float], shape : Shape, ctx : Context) = {
@@ -108,28 +210,88 @@ class NDArray(val nd : org.apache.mxnet.NDArray ) {

def serialize() : Array[Byte] = nd.serialize()

/**
* Releases the native memory. <br />
* The NDArrays this array depends on will NOT be disposed. <br />
* The object must not be used after it is disposed.
*/
def dispose() : Unit = nd.dispose()
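
A minimal sketch of explicit lifetime management with dispose (plain try/finally here stands in for the try-with-resources pattern mentioned in the class note):

    val tmp = NDArray.ones(new Context("cpu", 0), Array(2, 2))
    try {
      println(tmp.toArray.mkString(", "))
    } finally {
      tmp.dispose() // release the native memory deterministically
    }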

/**
* Disposes all NDArrays that helped construct this array. <br />
* e.g. (a * b + c).disposeDeps() will dispose a, b, c (including their dependencies) and a * b
* @return this array
*/
def disposeDeps() : NDArray = nd.disposeDepsExcept()
// def disposeDepsExcept(arr : Array[NDArray]) : NDArray = nd.disposeDepsExcept()

/**
* Returns a sliced NDArray that shares memory with the current one.
* NDArray only supports contiguous slicing on axis 0.
*
* @param start Starting index of the slice.
* @param stop Ending index of the slice.
*
* @return a sliced NDArray that shares memory with the current one.
*/
def slice(start : Int, stop : Int) : NDArray = nd.slice(start, stop)

/**
* Returns a sliced NDArray at the ith position of axis 0.
* @param i Index along axis 0.
* @return a sliced NDArray that shares memory with the current one.
*/
def slice (i : Int) : NDArray = nd.slice(i)
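
For example (a sketch; shapes are arbitrary):

    val a = NDArray.zeros(new Context("cpu", 0), Array(4, 3))
    val s = a.slice(1, 3) // shape (2, 3), shares memory with a
    val row = a.slice(2) // shape (1, 3)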

/**
* Returns a sub NDArray that shares memory with the current one.
* The first axis is rolled up, so its shape differs from that of slice(i, i+1).
* @param idx Index of the sub array.
*/
def at(idx : Int) : NDArray = nd.at(idx)

def T : NDArray = nd.T

/**
* Gets the data type of the current NDArray.
* @return The class representing the type of the current ndarray.
*/
def dtype : DType = nd.dtype

/**
* Returns a copy of the current array cast to the specified type.
* @param dtype Desired type of the result array.
* @return A copy of the array content with the new type.
*/
def asType(dtype : DType) : NDArray = nd.asType(dtype)

/**
* Returns a reshaped NDArray that shares memory with the current one.
* @param dims New shape.
*
* @return a reshaped NDArray that shares memory with the current one.
*/
def reshape(dims : Array[Int]) : NDArray = nd.reshape(dims)
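
For example (a sketch; the element count must stay the same):

    val a = NDArray.zeros(new Context("cpu", 0), Array(2, 6))
    val b = a.reshape(Array(3, 4)) // same 12 elements, viewed with shape (3, 4)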

/**
* Blocks until all pending write operations on the current NDArray are finished.
* This function returns once all pending writes to the current NDArray have
* finished. There may still be pending reads in progress when the
* function returns.
*/
def waitToRead(): Unit = nd.waitToRead()

/**
* Gets the context of the current NDArray.
* @return The context of the current NDArray.
*/
def context : Context = nd.context

/**
* Sets the values of the NDArray.
* @param value Value to set.
* @return The current NDArray.
*/
def set(value : Float) : NDArray = nd.set(value)
def set(other : NDArray) : NDArray = nd.set(other)
def set(other : Array[Float]) : NDArray = nd.set(other)
@@ -167,20 +329,57 @@ class NDArray(val nd : org.apache.mxnet.NDArray ) {
def lesserEqual(other : NDArray) : NDArray = this.nd <= other
def lesserEqual(other : Float) : NDArray = this.nd <= other

/**
* Returns a copied, flattened Java array of the current array (row-major).
* @return A copy of the array content.
*/
def toArray : Array[Float] = nd.toArray

/**
* Returns a CPU scalar (float) of the current ndarray.
* The ndarray must have shape (1,).
*
* @return The scalar representation of the ndarray.
*/
def toScalar : Float = nd.toScalar
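
A short sketch of both accessors:

    val ctx = new Context("cpu", 0)
    val x = NDArray.ones(ctx, Array(2, 2))
    x.toArray // Array(1f, 1f, 1f, 1f), a row-major copy
    NDArray.ones(ctx, Array(1)).toScalar // 1.0f; requires shape (1,)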

/**
* Copies the content of the current array to the other NDArray.
*
* @param other Target NDArray to copy data to.
* @return The copy target NDArray.
*/
def copyTo(other : NDArray) : NDArray = nd.copyTo(other)

/**
* Copies the content of the current array to a new NDArray in the given context.
*
* @param ctx Target context to copy data to.
* @return The copy target NDArray.
*/
def copyTo(ctx : Context) : NDArray = nd.copyTo(ctx)

/**
* Clones the current array.
* @return the copied NDArray, in the same context
*/
def copy() : NDArray = copyTo(this.context)
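
For example, a single-context sketch (with a GPU build the target could equally be a gpu context):

    val ctx = new Context("cpu", 0)
    val src = NDArray.ones(ctx, Array(2, 2))
    val dst = NDArray.zeros(ctx, Array(2, 2))
    src.copyTo(dst) // dst now holds the contents of src
    val dup = src.copy() // a fresh copy in the same context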

/**
* Gets the shape of the current NDArray.
* @return the Shape of the current ndarray
*/
def shape : Shape = nd.shape


def size : Int = shape.product

/**
* Returns an `NDArray` that lives in the target context. If the array
* is already in that context, the array itself is returned; otherwise, a copy is made.
* @param context The target context we want the return value to live in.
* @return A copy, or the array itself, as an `NDArray` living in the target context.
*/
def asInContext(context: Context): NDArray = nd.asInContext(context)

override def equals(obj: Any): Boolean = nd.equals(obj)
@@ -26,7 +26,21 @@ import org.apache.mxnet.javaapi.{Context, DataDesc, NDArray}
import scala.collection.JavaConverters
import scala.collection.JavaConverters._


/**
* The ObjectDetector class helps run object detection tasks, where the goal
* is to find bounding boxes and corresponding labels for objects in an image.
*
* @param modelPathPrefix Path prefix from where to load the model artifacts.
* These include the symbol, parameters, and synset.txt.
* Example: file://model-dir/ssd_resnet50_512 (containing
* ssd_resnet50_512-symbol.json, ssd_resnet50_512-0000.params,
* and synset.txt)
* @param inputDescriptors Descriptors defining the input node names, shape,
* layout and type parameters
* @param contexts Device contexts on which you want to run inference.
* Defaults to CPU.
* @param epoch Model epoch to load; defaults to 0
*/
class ObjectDetector(val objDetector: org.apache.mxnet.infer.ObjectDetector){

def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc], contexts:
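
A hypothetical construction sketch based on the parameters documented above (the model path is a placeholder, and inputDescriptor is assumed to be a pre-built DataDesc matching the model's input):

    val ctx = new Context("cpu", 0)
    val detector = new ObjectDetector("file://model-dir/ssd_resnet50_512",
      java.util.Arrays.asList(inputDescriptor), java.util.Arrays.asList(ctx), 0)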
