diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala index f207b62024b1..d234ac66bdd8 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala @@ -16,5 +16,9 @@ */ package org.apache.mxnet @AddNDArrayAPIs(false) +/** + * typesafe NDArray API: NDArray.api._ + * Main code will be generated during compile time through Macros + */ object NDArrayAPI { } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index 53a70429c655..f9a133963bd1 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -30,7 +30,7 @@ private[mxnet] class AddNDArrayFunctions(isContrib: Boolean) extends StaticAnnot } private[mxnet] class AddNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation { - private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.addNewDefs + private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.typeSafeAPIDefs } private[mxnet] object NDArrayMacro { @@ -39,22 +39,20 @@ private[mxnet] object NDArrayMacro { // scalastyle:off havetype def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { - impl(c)(false, false, annottees: _*) + impl(c)(annottees: _*) } - def addNewDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { - impl(c)(false, true, annottees: _*) + def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { + typeSafeAPIImpl(c)(annottees: _*) } // scalastyle:off havetype private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule() - private def impl(c: blackbox.Context)(addSuper: Boolean, - newAPI: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = { + private def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { import c.universe._ val isContrib: Boolean = c.prefix.tree match { case q"new AddNDArrayFunctions($b)" => c.eval[Boolean](c.Expr(b)) - case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b)) } val newNDArrayFunctions = { @@ -62,87 +60,104 @@ private[mxnet] object NDArrayMacro { else ndarrayFunctions.filter(!_.name.startsWith("_contrib_")) } - var functionDefs = List[Tree]() - if (!newAPI) { - functionDefs = newNDArrayFunctions flatMap { NDArrayfunction => + val functionDefs = newNDArrayFunctions flatMap { NDArrayfunction => val funcName = NDArrayfunction.name val termName = TermName(funcName) if (!NDArrayfunction.name.startsWith("_") || NDArrayfunction.name.startsWith("_contrib_")) { Seq( // scalastyle:off // def transpose(kwargs: Map[String, Any] = null)(args: Any*) - q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}", + q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef], // def transpose(args: Any*) - q"def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}" + q"def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef] // scalastyle:on ) } else { // Default private Seq( // scalastyle:off - q"private def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}", - q"private def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}" + q"private def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef], + q"private def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef] // scalastyle:on ) } } - } else { - functionDefs = newNDArrayFunctions map { ndarrayfunction => - // Construct argument field - var argDef = ListBuffer[String]() - ndarrayfunction.listOfArgs.foreach(ndarrayarg => { - val currArgName = ndarrayarg.argName match { - case "var" => "vari" - case "type" => "typeOf" - case default => ndarrayarg.argName - } - if (ndarrayarg.isOptional) { - argDef += s"${currArgName} : Option[${ndarrayarg.argType}] = None" - } - else { - argDef += s"${currArgName} : ${ndarrayarg.argType}" - } - }) - argDef += "name : String = null" - argDef += "attr : Map[String, String] = null" - // Construct Implementation field - var impl = ListBuffer[String]() - impl += "val map = scala.collection.mutable.Map[String, Any]()" - ndarrayfunction.listOfArgs.foreach({ ndarrayarg => - // var is a special word used to define variable in Scala, - // need to changed to something else in order to make it work - val currArgName = ndarrayarg.argName match { - case "var" => "vari" - case "type" => "typeOf" - case default => ndarrayarg.argName - } - var base = "map(\"" + ndarrayarg.argName + "\") = " + currArgName - if (ndarrayarg.isOptional) { - base = "if (!" + currArgName + ".isEmpty)" + base + ".get" - } - impl += base - }) - // scalastyle:off - impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", null, map.toMap)" - // scalastyle:on - // Combine and build the function string - val returnType = "org.apache.mxnet.NDArray" - var finalStr = s"def ${ndarrayfunction.name}New" - finalStr += s" (${argDef.mkString(",")}) : $returnType" - finalStr += s" = {${impl.mkString("\n")}}" - c.parse(finalStr) - } + structGeneration(c)(functionDefs, annottees : _*) + } + + private def typeSafeAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = { + import c.universe._ + + val isContrib: Boolean = c.prefix.tree match { + case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b)) + } + val newNDArrayFunctions = { + if (isContrib) ndarrayFunctions.filter(_.name.startsWith("_contrib_")) + else ndarrayFunctions.filter(!_.name.startsWith("_contrib_")) + } + + val functionDefs = newNDArrayFunctions map { ndarrayfunction => + + // Construct argument field + var argDef = ListBuffer[String]() + ndarrayfunction.listOfArgs.foreach(ndarrayarg => { + val currArgName = ndarrayarg.argName match { + case "var" => "vari" + case "type" => "typeOf" + case default => ndarrayarg.argName + } + if (ndarrayarg.isOptional) { + argDef += s"${currArgName} : Option[${ndarrayarg.argType}] = None" + } + else { + argDef += s"${currArgName} : ${ndarrayarg.argType}" + } + }) + argDef += "name : String = null" + argDef += "attr : Map[String, String] = null" + // Construct Implementation field + var impl = ListBuffer[String]() + impl += "val map = scala.collection.mutable.Map[String, Any]()" + ndarrayfunction.listOfArgs.foreach({ ndarrayarg => + // var is a special word used to define variable in Scala, + // need to changed to something else in order to make it work + val currArgName = ndarrayarg.argName match { + case "var" => "vari" + case "type" => "typeOf" + case default => ndarrayarg.argName + } + var base = "map(\"" + ndarrayarg.argName + "\") = " + currArgName + if (ndarrayarg.isOptional) { + base = "if (!" + currArgName + ".isEmpty)" + base + ".get" + } + impl += base + }) + // scalastyle:off + impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", null, map.toMap)" + // scalastyle:on + // Combine and build the function string + val returnType = "org.apache.mxnet.NDArray" + var finalStr = s"def ${ndarrayfunction.name}New" + finalStr += s" (${argDef.mkString(",")}) : $returnType" + finalStr += s" = {${impl.mkString("\n")}}" + c.parse(finalStr).asInstanceOf[DefDef] } + structGeneration(c)(functionDefs, annottees : _*) + } + private def structGeneration(c: blackbox.Context) + (funcDef : List[c.universe.DefDef], annottees: c.Expr[Any]*) + : c.Expr[Any] = { + import c.universe._ val inputs = annottees.map(_.tree).toList // pattern match on the inputs val modDefs = inputs map { case ClassDef(mods, name, something, template) => val q = template match { case Template(superMaybe, emptyValDef, defs) => - Template(superMaybe, emptyValDef, defs ++ functionDefs) + Template(superMaybe, emptyValDef, defs ++ funcDef) case ex => throw new IllegalArgumentException(s"Invalid template: $ex") } @@ -150,7 +165,7 @@ private[mxnet] object NDArrayMacro { case ModuleDef(mods, name, template) => val q = template match { case Template(superMaybe, emptyValDef, defs) => - Template(superMaybe, emptyValDef, defs ++ functionDefs) + Template(superMaybe, emptyValDef, defs ++ funcDef) case ex => throw new IllegalArgumentException(s"Invalid template: $ex") } @@ -163,6 +178,7 @@ private[mxnet] object NDArrayMacro { result } + // Convert C++ Types to Scala Types private def typeConversion(in : String, argType : String = "") : String = { in match {