Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
change the impl into individual functions and add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed May 11, 2018
1 parent fe01dff commit 44d93e0
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,9 @@
*/
package org.apache.mxnet
@AddNDArrayAPIs(false)
/**
* typesafe NDArray API: NDArray.api._
* Main code will be generated during compile time through Macros
*/
object NDArrayAPI {
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ private[mxnet] class AddNDArrayFunctions(isContrib: Boolean) extends StaticAnnot
}

private[mxnet] class AddNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation {
private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.addNewDefs
private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.typeSafeAPIDefs
}

private[mxnet] object NDArrayMacro {
Expand All @@ -39,118 +39,133 @@ private[mxnet] object NDArrayMacro {

// scalastyle:off havetype
def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
impl(c)(false, false, annottees: _*)
impl(c)(annottees: _*)
}
def addNewDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
impl(c)(false, true, annottees: _*)
def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
typeSafeAPIImpl(c)(annottees: _*)
}
// scalastyle:off havetype

private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule()

private def impl(c: blackbox.Context)(addSuper: Boolean,
newAPI: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = {
private def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
import c.universe._

val isContrib: Boolean = c.prefix.tree match {
case q"new AddNDArrayFunctions($b)" => c.eval[Boolean](c.Expr(b))
case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b))
}

val newNDArrayFunctions = {
if (isContrib) ndarrayFunctions.filter(_.name.startsWith("_contrib_"))
else ndarrayFunctions.filter(!_.name.startsWith("_contrib_"))
}

var functionDefs = List[Tree]()
if (!newAPI) {
functionDefs = newNDArrayFunctions flatMap { NDArrayfunction =>
val functionDefs = newNDArrayFunctions flatMap { NDArrayfunction =>
val funcName = NDArrayfunction.name
val termName = TermName(funcName)
if (!NDArrayfunction.name.startsWith("_") || NDArrayfunction.name.startsWith("_contrib_")) {
Seq(
// scalastyle:off
// def transpose(kwargs: Map[String, Any] = null)(args: Any*)
q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}",
q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef],
// def transpose(args: Any*)
q"def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}"
q"def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef]
// scalastyle:on
)
} else {
// Default private
Seq(
// scalastyle:off
q"private def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}",
q"private def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}"
q"private def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef],
q"private def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef]
// scalastyle:on
)
}
}
} else {
functionDefs = newNDArrayFunctions map { ndarrayfunction =>

// Construct argument field
var argDef = ListBuffer[String]()
ndarrayfunction.listOfArgs.foreach(ndarrayarg => {
val currArgName = ndarrayarg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case default => ndarrayarg.argName
}
if (ndarrayarg.isOptional) {
argDef += s"${currArgName} : Option[${ndarrayarg.argType}] = None"
}
else {
argDef += s"${currArgName} : ${ndarrayarg.argType}"
}
})
argDef += "name : String = null"
argDef += "attr : Map[String, String] = null"
// Construct Implementation field
var impl = ListBuffer[String]()
impl += "val map = scala.collection.mutable.Map[String, Any]()"
ndarrayfunction.listOfArgs.foreach({ ndarrayarg =>
// var is a special word used to define variable in Scala,
// need to changed to something else in order to make it work
val currArgName = ndarrayarg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case default => ndarrayarg.argName
}
var base = "map(\"" + ndarrayarg.argName + "\") = " + currArgName
if (ndarrayarg.isOptional) {
base = "if (!" + currArgName + ".isEmpty)" + base + ".get"
}
impl += base
})
// scalastyle:off
impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", null, map.toMap)"
// scalastyle:on
// Combine and build the function string
val returnType = "org.apache.mxnet.NDArray"
var finalStr = s"def ${ndarrayfunction.name}New"
finalStr += s" (${argDef.mkString(",")}) : $returnType"
finalStr += s" = {${impl.mkString("\n")}}"
c.parse(finalStr)
}
structGeneration(c)(functionDefs, annottees : _*)
}

private def typeSafeAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = {
import c.universe._

val isContrib: Boolean = c.prefix.tree match {
case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b))
}
val newNDArrayFunctions = {
if (isContrib) ndarrayFunctions.filter(_.name.startsWith("_contrib_"))
else ndarrayFunctions.filter(!_.name.startsWith("_contrib_"))
}

val functionDefs = newNDArrayFunctions map { ndarrayfunction =>

// Construct argument field
var argDef = ListBuffer[String]()
ndarrayfunction.listOfArgs.foreach(ndarrayarg => {
val currArgName = ndarrayarg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case default => ndarrayarg.argName
}
if (ndarrayarg.isOptional) {
argDef += s"${currArgName} : Option[${ndarrayarg.argType}] = None"
}
else {
argDef += s"${currArgName} : ${ndarrayarg.argType}"
}
})
argDef += "name : String = null"
argDef += "attr : Map[String, String] = null"
// Construct Implementation field
var impl = ListBuffer[String]()
impl += "val map = scala.collection.mutable.Map[String, Any]()"
ndarrayfunction.listOfArgs.foreach({ ndarrayarg =>
// var is a special word used to define variable in Scala,
// need to changed to something else in order to make it work
val currArgName = ndarrayarg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case default => ndarrayarg.argName
}
var base = "map(\"" + ndarrayarg.argName + "\") = " + currArgName
if (ndarrayarg.isOptional) {
base = "if (!" + currArgName + ".isEmpty)" + base + ".get"
}
impl += base
})
// scalastyle:off
impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", null, map.toMap)"
// scalastyle:on
// Combine and build the function string
val returnType = "org.apache.mxnet.NDArray"
var finalStr = s"def ${ndarrayfunction.name}New"
finalStr += s" (${argDef.mkString(",")}) : $returnType"
finalStr += s" = {${impl.mkString("\n")}}"
c.parse(finalStr).asInstanceOf[DefDef]
}

structGeneration(c)(functionDefs, annottees : _*)
}

private def structGeneration(c: blackbox.Context)
(funcDef : List[c.universe.DefDef], annottees: c.Expr[Any]*)
: c.Expr[Any] = {
import c.universe._
val inputs = annottees.map(_.tree).toList
// pattern match on the inputs
val modDefs = inputs map {
case ClassDef(mods, name, something, template) =>
val q = template match {
case Template(superMaybe, emptyValDef, defs) =>
Template(superMaybe, emptyValDef, defs ++ functionDefs)
Template(superMaybe, emptyValDef, defs ++ funcDef)
case ex =>
throw new IllegalArgumentException(s"Invalid template: $ex")
}
ClassDef(mods, name, something, q)
case ModuleDef(mods, name, template) =>
val q = template match {
case Template(superMaybe, emptyValDef, defs) =>
Template(superMaybe, emptyValDef, defs ++ functionDefs)
Template(superMaybe, emptyValDef, defs ++ funcDef)
case ex =>
throw new IllegalArgumentException(s"Invalid template: $ex")
}
Expand All @@ -163,6 +178,7 @@ private[mxnet] object NDArrayMacro {
result
}


// Convert C++ Types to Scala Types
private def typeConversion(in : String, argType : String = "") : String = {
in match {
Expand Down

0 comments on commit 44d93e0

Please sign in to comment.