Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[WIP][MXNET-918] Random api #12489

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ object NDArray extends NDArrayBase {

val api = NDArrayAPI

val random = NDArrayRandomAPI

private def addDependency(froms: Array[NDArray], tos: Array[NDArray]): Unit = {
froms.foreach { from =>
val weakRef = new WeakReference(from)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,13 @@ package org.apache.mxnet
object NDArrayAPI extends NDArrayAPIBase {
// TODO: Implement CustomOp for NDArray
}

@AddNDArrayRandomAPIs(false)
/**
* typesafe NDArray random module: NDArray.random._
* Main code will be generated during compile time through Macros
*/
object NDArrayRandomAPI extends NDArrayRandomAPIBase {

}

Original file line number Diff line number Diff line change
Expand Up @@ -840,8 +840,12 @@ object Symbol extends SymbolBase {
private val functions: Map[String, SymbolFunction] = initSymbolModule()
private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3)

type SymbolOrFloat = Any

val api = SymbolAPI

val random = SymbolRandomAPI

def pow(sym1: Symbol, sym2: Symbol): Symbol = {
Symbol.createFromListedSymbols("_Power")(Array(sym1, sym2))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,13 @@ object SymbolAPI extends SymbolAPIBase {
Symbol.createSymbolGeneral("Custom", name, attr, Seq(), map.toMap)
}
}

@AddSymbolRandomAPIs(false)
/**
* typesafe Symbol random module: Symbol.random._
* Main code will be generated during compile time through Macros
*/
object SymbolRandomAPI extends SymbolRandomAPIBase {

}

Original file line number Diff line number Diff line change
Expand Up @@ -576,4 +576,22 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
assert(arr.internal.toDoubleArray === Array(2d, 2d))
assert(arr.internal.toByteArray === Array(2.toByte, 2.toByte))
}

test("random module is generated properly") {
val lam = NDArray.ones(1, 2)
val rnd = NDArray.random.poisson(lam = Some(lam), shape = Some(Shape(3, 4)))
val rnd2 = NDArray.random.poisson(lam = Some(1f), shape = Some(Shape(3, 4)))
assert(rnd.shape === Shape(1, 2, 3, 4))
assert(rnd2.shape === Shape(3, 4))
}

test("random module is generated properly - special case of 'normal'") {
val mu = NDArray.ones(1, 2)
val sigma = NDArray.ones(1, 2) * 2
val rnd = NDArray.random.normal(mu = Some(mu), sigma = Some(sigma), shape = Some(Shape(3, 4)))
val rnd2 = NDArray.random.normal(mu = Some(1f), sigma = Some(2f),
shape = Some(Shape(3, 4)))
assert(rnd.shape === Shape(1, 2, 3, 4))
assert(rnd2.shape === Shape(3, 4))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ package org.apache.mxnet

import org.scalatest.{BeforeAndAfterAll, FunSuite}


class SymbolSuite extends FunSuite with BeforeAndAfterAll {

test("symbol compose") {
val data = Symbol.Variable("data")

Expand Down Expand Up @@ -71,4 +73,27 @@ class SymbolSuite extends FunSuite with BeforeAndAfterAll {
val data2 = data.clone()
assert(data.toJson === data2.toJson)
}

test("random module is generated properly") {
val lam = Symbol.Variable("lam")
val rnd = Symbol.random.poisson(lam = Some(lam), shape = Some(Shape(2, 2)))
val rnd2 = Symbol.random.poisson(lam = Some(1f), shape = Some(Shape(2, 2)))
// scalastyle:off println
println(s"Symbol.random.poisson debug info: ${rnd.debugStr}")
println(s"Symbol.random.poisson debug info: ${rnd2.debugStr}")
// scalastyle:on println
}

test("random module is generated properly - special case of 'normal'") {
val loc = Symbol.Variable("loc")
val scale = Symbol.Variable("scale")
val rnd = Symbol.random.normal(mu = Some(loc), sigma = Some(scale),
shape = Some(Shape(2, 2)))
val rnd2 = Symbol.random.normal(mu = Some(1f), sigma = Some(2f),
shape = Some(Shape(2, 2)))
// scalastyle:off println
println(s"Symbol.random.sample_normal debug info: ${rnd.debugStr}")
println(s"Symbol.random.random_normal debug info: ${rnd2.debugStr}")
// scalastyle:on println
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.mxnet

import org.apache.mxnet.init.Base._
import org.apache.mxnet.utils.CToScalaUtils
import java.io._
import java.security.MessageDigest

Expand All @@ -29,98 +27,131 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
* Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
* The code will be executed during Macros stage and file live in Core stage
*/
private[mxnet] object APIDocGenerator{
case class absClassArg(argName : String, argType : String, argDesc : String, isOptional : Boolean)
case class absClassFunction(name : String, desc : String,
listOfArgs: List[absClassArg], returnType : String)
private[mxnet] object APIDocGenerator extends GeneratorBase {
type absClassArg = Arg
type absClassFunction = Func


def main(args: Array[String]) : Unit = {
def main(args: Array[String]): Unit = {
val FILE_PATH = args(0)
val hashCollector = ListBuffer[String]()
hashCollector += absClassGen(FILE_PATH, true)
hashCollector += absClassGen(FILE_PATH, false)
hashCollector += absRndClassGen(FILE_PATH, true)
hashCollector += absRndClassGen(FILE_PATH, false)
hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
val finalHash = hashCollector.mkString("\n")
}

def MD5Generator(input : String) : String = {
def MD5Generator(input: String): String = {
val md = MessageDigest.getInstance("MD5")
md.update(input.getBytes("UTF-8"))
val digest = md.digest()
org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
}

def absClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
// scalastyle:off
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
// Defines Operators that should not generated
def absRndClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val funcs = buildRandomFunctionList(isSymbol)

val body = funcs.map(func => {
val scalaDoc = generateAPIDocFromBackend(func)
val decl = generateAPISignature(func, isSymbol)
s"$scalaDoc\n$decl"
})
writeFile(
FILE_PATH,
if (isSymbol) "SymbolRandomAPIBase" else "NDArrayRandomAPIBase",
body)
}

def absClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val notGenerated = Set("Custom")
// TODO: Add Filter to the same location in case of refactor
val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_"))
val funcs = buildFunctionList(isSymbol)
.filterNot(_.name.startsWith("_"))
.filterNot(ele => notGenerated.contains(ele.name))
.map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction)
val defBody = generateAPISignature(absClassFunction, isSymbol)
s"$scalaDoc\n$defBody"
val body = funcs.map(func => {
val scalaDoc = generateAPIDocFromBackend(func)
val decl = generateAPISignature(func, isSymbol)
s"$scalaDoc\n$decl"
})
val packageName = if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase"
val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n"
val scalaStyle = "// scalastyle:off"
val packageDef = "package org.apache.mxnet"
val imports = "import org.apache.mxnet.annotation.Experimental"
val absClassDef = s"abstract class $packageName"
val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
pw.write(finalStr)
pw.close()
MD5Generator(finalStr)
writeFile(
FILE_PATH,
if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
body)
}

def nonTypeSafeClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
// scalastyle:off
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
val absFuncs = absClassFunctions.map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction, false)
if (isSymbol) {
val defBody = s"def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null)(args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): org.apache.mxnet.Symbol"
s"$scalaDoc\n$defBody"
} else {
val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs: Map[String, Any] = null)(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
val defBody = s"def ${absClassFunction.name}(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody"
}
})
def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val absClassFunctions = buildFunctionList(isSymbol)
val absFuncs = absClassFunctions
.filterNot(_.name.startsWith("_"))
mdespriee marked this conversation as resolved.
Show resolved Hide resolved
mdespriee marked this conversation as resolved.
Show resolved Hide resolved
.map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction, false)
if (isSymbol) {
val defBody =
s"def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null)" +
s"(args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): " +
s"org.apache.mxnet.Symbol"
s"$scalaDoc\n$defBody"
} else {
val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs: Map[String, Any] = null)" +
s"(args: Any*): " +
s"org.apache.mxnet.NDArrayFuncReturn"
val defBody = s"def ${absClassFunction.name}(args: Any*): " +
s"org.apache.mxnet.NDArrayFuncReturn"
s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody"
}
})
val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase"
val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n"
writeFile(FILE_PATH, packageName, absFuncs)
}

def writeFile(FILE_PATH: String, packageName: String, body: Seq[String]): String = {
val apacheLicence =
"""/*
|* Licensed to the Apache Software Foundation (ASF) under one or more
|* contributor license agreements. See the NOTICE file distributed with
|* this work for additional information regarding copyright ownership.
|* The ASF licenses this file to You under the Apache License, Version 2.0
|* (the "License"); you may not use this file except in compliance with
|* the License. You may obtain a copy of the License at
|*
|* http://www.apache.org/licenses/LICENSE-2.0
|*
|* Unless required by applicable law or agreed to in writing, software
|* distributed under the License is distributed on an "AS IS" BASIS,
|* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|* See the License for the specific language governing permissions and
|* limitations under the License.
|*/
|""".stripMargin
val scalaStyle = "// scalastyle:off"
val packageDef = "package org.apache.mxnet"
val imports = "import org.apache.mxnet.annotation.Experimental"
val absClassDef = s"abstract class $packageName"
val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
import java.io._
val finalStr =
s"""$apacheLicence
|$scalaStyle
|$packageDef
|$imports
|$absClassDef {
|${body.mkString("\n")}
|}""".stripMargin
val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
pw.write(finalStr)
pw.close()
MD5Generator(finalStr)
}

// Generate ScalaDoc type
def generateAPIDocFromBackend(func : absClassFunction, withParam : Boolean = true) : String = {
def generateAPIDocFromBackend(func: absClassFunction, withParam: Boolean = true): String = {
val desc = ArrayBuffer[String]()
desc += " * <pre>"
func.desc.split("\n").foreach({ currStr =>
func.desc.split("\n").foreach({ currStr =>
desc += s" * $currStr"
})
desc += " * </pre>"
val params = func.listOfArgs.map({ absClassArg =>
val currArgName = absClassArg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case _ => absClassArg.argName
}
s" * @param $currArgName\t\t${absClassArg.argDesc}"
s" * @param ${absClassArg.safeArgName}\t\t${absClassArg.argDesc}"
})
val returnType = s" * @return ${func.returnType}"
if (withParam) {
Expand All @@ -130,65 +161,23 @@ private[mxnet] object APIDocGenerator{
}
}

def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = {
var argDef = ListBuffer[String]()
func.listOfArgs.foreach(absClassArg => {
val currArgName = absClassArg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case _ => absClassArg.argName
}
if (absClassArg.isOptional) {
argDef += s"$currArgName : Option[${absClassArg.argType}] = None"
}
else {
argDef += s"$currArgName : ${absClassArg.argType}"
}
})
var returnType = func.returnType
def generateAPISignature(func: absClassFunction, isSymbol: Boolean): String = {
val argDef = ListBuffer[String]()

argDef ++= buildArgDefs(func)

if (isSymbol) {
argDef += "name : String = null"
argDef += "attr : Map[String, String] = null"
} else {
argDef += "out : Option[NDArray] = None"
returnType = "org.apache.mxnet.NDArrayFuncReturn"
}

val returnType = func.returnType

val experimentalTag = "@Experimental"
s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : $returnType"
}


// List and add all the atomic symbol functions to current module.
private def getSymbolNDArrayMethods(isSymbol : Boolean): List[absClassFunction] = {
val opNames = ListBuffer.empty[String]
val returnType = if (isSymbol) "Symbol" else "NDArray"
_LIB.mxListAllOpNames(opNames)
// TODO: Add '_linalg_', '_sparse_', '_image_' support
// TODO: Add Filter to the same location in case of refactor
opNames.map(opName => {
val opHandle = new RefLong
_LIB.nnGetOpHandle(opName, opHandle)
makeAtomicSymbolFunction(opHandle.value, opName, "org.apache.mxnet." + returnType)
}).toList.filterNot(_.name.startsWith("_"))
}

// Create an atomic symbol function by handle and function name.
private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String, returnType : String)
: absClassFunction = {
val name = new RefString
val desc = new RefString
val keyVarNumArgs = new RefString
val numArgs = new RefInt
val argNames = ListBuffer.empty[String]
val argTypes = ListBuffer.empty[String]
val argDescs = ListBuffer.empty[String]

_LIB.mxSymbolGetAtomicSymbolInfo(
handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) =>
val typeAndOption = CToScalaUtils.argumentCleaner(argName, argType, returnType)
new absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2)
}
new absClassFunction(aliasName, desc.value, argList.toList, returnType)
}
}
Loading