From e34c98c9daa45fd4737961546e1284064df441be Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 2 May 2018 15:45:46 -0700 Subject: [PATCH 1/5] Add new NDArray APIs --- .../scala/org/apache/mxnet/NDArrayMacro.scala | 96 +++++++++++++++---- 1 file changed, 76 insertions(+), 20 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index 036b9ec47530..3abdadaf7476 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -30,7 +30,8 @@ private[mxnet] class AddNDArrayFunctions(isContrib: Boolean) extends StaticAnnot } private[mxnet] object NDArrayMacro { - case class NDArrayFunction(handle: NDArrayHandle) + case class NDArrayArg(argName: String, argType: String, isOptional : Boolean) + case class NDArrayFunction(name: String, listOfArgs: List[NDArrayArg]) // scalastyle:off havetype def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { @@ -38,7 +39,7 @@ private[mxnet] object NDArrayMacro { } // scalastyle:off havetype - private val ndarrayFunctions: Map[String, NDArrayFunction] = initNDArrayModule() + private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule() private def impl(c: blackbox.Context)(addSuper: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = { import c.universe._ @@ -48,21 +49,12 @@ private[mxnet] object NDArrayMacro { } val newNDArrayFunctions = { - if (isContrib) ndarrayFunctions.filter(_._1.startsWith("_contrib_")) - else ndarrayFunctions.filter(!_._1.startsWith("_contrib_")) + if (isContrib) ndarrayFunctions.filter(_.name.startsWith("_contrib_")) + else ndarrayFunctions.filter(!_.name.startsWith("_contrib_")) } - val functionDefs = newNDArrayFunctions flatMap { case (funcName, funcProp) => - val functionScope = { - if (isContrib) Modifiers() - else { - if (funcName.startsWith("_")) Modifiers(Flag.PRIVATE) else Modifiers() - } - } - val newName = { - if (isContrib) funcName.substring(funcName.indexOf("_contrib_") + "_contrib_".length()) - else funcName - } + val functionDefs = newNDArrayFunctions flatMap { NDArrayfunction => + val funcName = NDArrayfunction.name val termName = TermName(funcName) // It will generate definition something like, Seq( @@ -102,20 +94,80 @@ private[mxnet] object NDArrayMacro { result } + // Convert C++ Types to Scala Types + private def typeConversion(in : String, argType : String = "") : String = { + in match { + case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape" + case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "org.apache.mxnet.NDArray" + case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]" + => "Array[org.apache.mxnet.NDArray]" + case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat" + case "int" | "intorNone" | "int(non-negative)" => "Int" + case "long" | "long(non-negative)" => "Long" + case "double" | "doubleorNone" => "Double" + case "string" => "String" + case "boolean" => "Boolean" + case "tupleof" | "tupleof" | "ptr" | "" => "Any" + case default => throw new IllegalArgumentException( + s"Invalid type for args: $default, $argType") + } + } + + + /** + * By default, the argType come from the C++ API is a description more than a single word + * For Example: + * , , + * The three field shown above do not usually come at the same time + * This function used the above format to determine if the argument is + * optional, what is it Scala type and possibly pass in a default value + * @param argType Raw arguement Type description + * @return (Scala_Type, isOptional) + */ + private def argumentCleaner(argType : String) : (String, Boolean) = { + val spaceRemoved = argType.replaceAll("\\s+", "") + var commaRemoved : Array[String] = new Array[String](0) + // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'} + if (spaceRemoved.charAt(0)== '{') { + val endIdx = spaceRemoved.indexOf('}') + commaRemoved = spaceRemoved.substring(endIdx + 1).split(",") + commaRemoved(0) = "string" + } else { + commaRemoved = spaceRemoved.split(",") + } + // Optional Field + if (commaRemoved.length >= 3) { + // arg: Type, optional, default = Null + require(commaRemoved(1).equals("optional")) + require(commaRemoved(2).startsWith("default=")) + (typeConversion(commaRemoved(0), argType), true) + } else if (commaRemoved.length == 2 || commaRemoved.length == 1) { + val tempType = typeConversion(commaRemoved(0), argType) + val tempOptional = tempType.equals("org.apache.mxnet.NDArray") + (tempType, tempOptional) + } else { + throw new IllegalArgumentException( + s"Unrecognized arg field: $argType, ${commaRemoved.length}") + } + + } + + // List and add all the atomic symbol functions to current module. - private def initNDArrayModule(): Map[String, NDArrayFunction] = { + private def initNDArrayModule(): List[NDArrayFunction] = { val opNames = ListBuffer.empty[String] _LIB.mxListAllOpNames(opNames) + // TODO: Add '_linalg_', '_sparse_', '_image_' support opNames.map(opName => { val opHandle = new RefLong _LIB.nnGetOpHandle(opName, opHandle) makeNDArrayFunction(opHandle.value, opName) - }).toMap + }).toList } // Create an atomic symbol function by handle and function name. private def makeNDArrayFunction(handle: NDArrayHandle, aliasName: String) - : (String, NDArrayFunction) = { + : NDArrayFunction = { val name = new RefString val desc = new RefString val keyVarNumArgs = new RefString @@ -136,10 +188,14 @@ private[mxnet] object NDArrayMacro { val docStr = s"$aliasName $realName\n${desc.value}\n\n$paramStr\n$extraDoc\n" // scalastyle:off println if (System.getenv("MXNET4J_PRINT_OP_DEF") != null - && System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") { + && System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") { println("NDArray function definition:\n" + docStr) } // scalastyle:on println - (aliasName, new NDArrayFunction(handle)) + val argList = argNames zip argTypes map { case (argName, argType) => + val typeAndOption = argumentCleaner(argType) + new NDArrayArg(argName, typeAndOption._1, typeAndOption._2) + } + new NDArrayFunction(aliasName, argList.toList) } } From 00871d9639a33693e3019c7dd1e524ff8a1914f7 Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 3 May 2018 13:10:03 -0700 Subject: [PATCH 2/5] Add NDArray APIs --- .../main/scala/org/apache/mxnet/NDArray.scala | 2 + .../scala/org/apache/mxnet/NDArrayAPI.scala | 20 ++++ .../scala/org/apache/mxnet/NDArrayMacro.scala | 94 ++++++++++++++++--- 3 files changed, 103 insertions(+), 13 deletions(-) create mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 416f2d74e828..469107aa58cc 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -37,6 +37,8 @@ object NDArray { private val functions: Map[String, NDArrayFunction] = initNDArrayModule() + val api = NDArrayAPI + private def addDependency(froms: Array[NDArray], tos: Array[NDArray]): Unit = { froms.foreach { from => val weakRef = new WeakReference(from) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala new file mode 100644 index 000000000000..f207b62024b1 --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mxnet +@AddNDArrayAPIs(false) +object NDArrayAPI { +} diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index 3abdadaf7476..53a70429c655 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -29,23 +29,32 @@ private[mxnet] class AddNDArrayFunctions(isContrib: Boolean) extends StaticAnnot private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.addDefs } +private[mxnet] class AddNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation { + private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.addNewDefs +} + private[mxnet] object NDArrayMacro { case class NDArrayArg(argName: String, argType: String, isOptional : Boolean) case class NDArrayFunction(name: String, listOfArgs: List[NDArrayArg]) // scalastyle:off havetype def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { - impl(c)(false, annottees: _*) + impl(c)(false, false, annottees: _*) + } + def addNewDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { + impl(c)(false, true, annottees: _*) } // scalastyle:off havetype private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule() - private def impl(c: blackbox.Context)(addSuper: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = { + private def impl(c: blackbox.Context)(addSuper: Boolean, + newAPI: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = { import c.universe._ val isContrib: Boolean = c.prefix.tree match { case q"new AddNDArrayFunctions($b)" => c.eval[Boolean](c.Expr(b)) + case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b)) } val newNDArrayFunctions = { @@ -53,20 +62,80 @@ private[mxnet] object NDArrayMacro { else ndarrayFunctions.filter(!_.name.startsWith("_contrib_")) } - val functionDefs = newNDArrayFunctions flatMap { NDArrayfunction => - val funcName = NDArrayfunction.name - val termName = TermName(funcName) - // It will generate definition something like, - Seq( + var functionDefs = List[Tree]() + if (!newAPI) { + functionDefs = newNDArrayFunctions flatMap { NDArrayfunction => + val funcName = NDArrayfunction.name + val termName = TermName(funcName) + if (!NDArrayfunction.name.startsWith("_") || NDArrayfunction.name.startsWith("_contrib_")) { + Seq( + // scalastyle:off + // def transpose(kwargs: Map[String, Any] = null)(args: Any*) + q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}", + // def transpose(args: Any*) + q"def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}" + // scalastyle:on + ) + } else { + // Default private + Seq( + // scalastyle:off + q"private def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}", + q"private def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}" + // scalastyle:on + ) + } + } + } else { + functionDefs = newNDArrayFunctions map { ndarrayfunction => + + // Construct argument field + var argDef = ListBuffer[String]() + ndarrayfunction.listOfArgs.foreach(ndarrayarg => { + val currArgName = ndarrayarg.argName match { + case "var" => "vari" + case "type" => "typeOf" + case default => ndarrayarg.argName + } + if (ndarrayarg.isOptional) { + argDef += s"${currArgName} : Option[${ndarrayarg.argType}] = None" + } + else { + argDef += s"${currArgName} : ${ndarrayarg.argType}" + } + }) + argDef += "name : String = null" + argDef += "attr : Map[String, String] = null" + // Construct Implementation field + var impl = ListBuffer[String]() + impl += "val map = scala.collection.mutable.Map[String, Any]()" + ndarrayfunction.listOfArgs.foreach({ ndarrayarg => + // var is a special word used to define variable in Scala, + // need to changed to something else in order to make it work + val currArgName = ndarrayarg.argName match { + case "var" => "vari" + case "type" => "typeOf" + case default => ndarrayarg.argName + } + var base = "map(\"" + ndarrayarg.argName + "\") = " + currArgName + if (ndarrayarg.isOptional) { + base = "if (!" + currArgName + ".isEmpty)" + base + ".get" + } + impl += base + }) // scalastyle:off - // def transpose(kwargs: Map[String, Any] = null)(args: Any*) - q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}", - // def transpose(args: Any*) - q"def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}" + impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", null, map.toMap)" // scalastyle:on - ) + // Combine and build the function string + val returnType = "org.apache.mxnet.NDArray" + var finalStr = s"def ${ndarrayfunction.name}New" + finalStr += s" (${argDef.mkString(",")}) : $returnType" + finalStr += s" = {${impl.mkString("\n")}}" + c.parse(finalStr) + } } + val inputs = annottees.map(_.tree).toList // pattern match on the inputs val modDefs = inputs map { @@ -157,7 +226,6 @@ private[mxnet] object NDArrayMacro { private def initNDArrayModule(): List[NDArrayFunction] = { val opNames = ListBuffer.empty[String] _LIB.mxListAllOpNames(opNames) - // TODO: Add '_linalg_', '_sparse_', '_image_' support opNames.map(opName => { val opHandle = new RefLong _LIB.nnGetOpHandle(opName, opHandle) From 9fd650561cedb5fa306c883d4bcac883a6178dae Mon Sep 17 00:00:00 2001 From: Qing Date: Fri, 11 May 2018 10:12:53 -0700 Subject: [PATCH 3/5] change the impl into individual functions and add comments --- .../scala/org/apache/mxnet/NDArrayAPI.scala | 4 + .../scala/org/apache/mxnet/NDArrayMacro.scala | 140 ++++++++++-------- 2 files changed, 82 insertions(+), 62 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala index f207b62024b1..d234ac66bdd8 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala @@ -16,5 +16,9 @@ */ package org.apache.mxnet @AddNDArrayAPIs(false) +/** + * typesafe NDArray API: NDArray.api._ + * Main code will be generated during compile time through Macros + */ object NDArrayAPI { } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index 53a70429c655..f9a133963bd1 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -30,7 +30,7 @@ private[mxnet] class AddNDArrayFunctions(isContrib: Boolean) extends StaticAnnot } private[mxnet] class AddNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation { - private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.addNewDefs + private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.typeSafeAPIDefs } private[mxnet] object NDArrayMacro { @@ -39,22 +39,20 @@ private[mxnet] object NDArrayMacro { // scalastyle:off havetype def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { - impl(c)(false, false, annottees: _*) + impl(c)(annottees: _*) } - def addNewDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { - impl(c)(false, true, annottees: _*) + def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { + typeSafeAPIImpl(c)(annottees: _*) } // scalastyle:off havetype private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule() - private def impl(c: blackbox.Context)(addSuper: Boolean, - newAPI: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = { + private def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { import c.universe._ val isContrib: Boolean = c.prefix.tree match { case q"new AddNDArrayFunctions($b)" => c.eval[Boolean](c.Expr(b)) - case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b)) } val newNDArrayFunctions = { @@ -62,87 +60,104 @@ private[mxnet] object NDArrayMacro { else ndarrayFunctions.filter(!_.name.startsWith("_contrib_")) } - var functionDefs = List[Tree]() - if (!newAPI) { - functionDefs = newNDArrayFunctions flatMap { NDArrayfunction => + val functionDefs = newNDArrayFunctions flatMap { NDArrayfunction => val funcName = NDArrayfunction.name val termName = TermName(funcName) if (!NDArrayfunction.name.startsWith("_") || NDArrayfunction.name.startsWith("_contrib_")) { Seq( // scalastyle:off // def transpose(kwargs: Map[String, Any] = null)(args: Any*) - q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}", + q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef], // def transpose(args: Any*) - q"def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}" + q"def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef] // scalastyle:on ) } else { // Default private Seq( // scalastyle:off - q"private def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}", - q"private def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}" + q"private def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef], + q"private def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef] // scalastyle:on ) } } - } else { - functionDefs = newNDArrayFunctions map { ndarrayfunction => - // Construct argument field - var argDef = ListBuffer[String]() - ndarrayfunction.listOfArgs.foreach(ndarrayarg => { - val currArgName = ndarrayarg.argName match { - case "var" => "vari" - case "type" => "typeOf" - case default => ndarrayarg.argName - } - if (ndarrayarg.isOptional) { - argDef += s"${currArgName} : Option[${ndarrayarg.argType}] = None" - } - else { - argDef += s"${currArgName} : ${ndarrayarg.argType}" - } - }) - argDef += "name : String = null" - argDef += "attr : Map[String, String] = null" - // Construct Implementation field - var impl = ListBuffer[String]() - impl += "val map = scala.collection.mutable.Map[String, Any]()" - ndarrayfunction.listOfArgs.foreach({ ndarrayarg => - // var is a special word used to define variable in Scala, - // need to changed to something else in order to make it work - val currArgName = ndarrayarg.argName match { - case "var" => "vari" - case "type" => "typeOf" - case default => ndarrayarg.argName - } - var base = "map(\"" + ndarrayarg.argName + "\") = " + currArgName - if (ndarrayarg.isOptional) { - base = "if (!" + currArgName + ".isEmpty)" + base + ".get" - } - impl += base - }) - // scalastyle:off - impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", null, map.toMap)" - // scalastyle:on - // Combine and build the function string - val returnType = "org.apache.mxnet.NDArray" - var finalStr = s"def ${ndarrayfunction.name}New" - finalStr += s" (${argDef.mkString(",")}) : $returnType" - finalStr += s" = {${impl.mkString("\n")}}" - c.parse(finalStr) - } + structGeneration(c)(functionDefs, annottees : _*) + } + + private def typeSafeAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = { + import c.universe._ + + val isContrib: Boolean = c.prefix.tree match { + case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b)) + } + val newNDArrayFunctions = { + if (isContrib) ndarrayFunctions.filter(_.name.startsWith("_contrib_")) + else ndarrayFunctions.filter(!_.name.startsWith("_contrib_")) + } + + val functionDefs = newNDArrayFunctions map { ndarrayfunction => + + // Construct argument field + var argDef = ListBuffer[String]() + ndarrayfunction.listOfArgs.foreach(ndarrayarg => { + val currArgName = ndarrayarg.argName match { + case "var" => "vari" + case "type" => "typeOf" + case default => ndarrayarg.argName + } + if (ndarrayarg.isOptional) { + argDef += s"${currArgName} : Option[${ndarrayarg.argType}] = None" + } + else { + argDef += s"${currArgName} : ${ndarrayarg.argType}" + } + }) + argDef += "name : String = null" + argDef += "attr : Map[String, String] = null" + // Construct Implementation field + var impl = ListBuffer[String]() + impl += "val map = scala.collection.mutable.Map[String, Any]()" + ndarrayfunction.listOfArgs.foreach({ ndarrayarg => + // var is a special word used to define variable in Scala, + // need to changed to something else in order to make it work + val currArgName = ndarrayarg.argName match { + case "var" => "vari" + case "type" => "typeOf" + case default => ndarrayarg.argName + } + var base = "map(\"" + ndarrayarg.argName + "\") = " + currArgName + if (ndarrayarg.isOptional) { + base = "if (!" + currArgName + ".isEmpty)" + base + ".get" + } + impl += base + }) + // scalastyle:off + impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", null, map.toMap)" + // scalastyle:on + // Combine and build the function string + val returnType = "org.apache.mxnet.NDArray" + var finalStr = s"def ${ndarrayfunction.name}New" + finalStr += s" (${argDef.mkString(",")}) : $returnType" + finalStr += s" = {${impl.mkString("\n")}}" + c.parse(finalStr).asInstanceOf[DefDef] } + structGeneration(c)(functionDefs, annottees : _*) + } + private def structGeneration(c: blackbox.Context) + (funcDef : List[c.universe.DefDef], annottees: c.Expr[Any]*) + : c.Expr[Any] = { + import c.universe._ val inputs = annottees.map(_.tree).toList // pattern match on the inputs val modDefs = inputs map { case ClassDef(mods, name, something, template) => val q = template match { case Template(superMaybe, emptyValDef, defs) => - Template(superMaybe, emptyValDef, defs ++ functionDefs) + Template(superMaybe, emptyValDef, defs ++ funcDef) case ex => throw new IllegalArgumentException(s"Invalid template: $ex") } @@ -150,7 +165,7 @@ private[mxnet] object NDArrayMacro { case ModuleDef(mods, name, template) => val q = template match { case Template(superMaybe, emptyValDef, defs) => - Template(superMaybe, emptyValDef, defs ++ functionDefs) + Template(superMaybe, emptyValDef, defs ++ funcDef) case ex => throw new IllegalArgumentException(s"Invalid template: $ex") } @@ -163,6 +178,7 @@ private[mxnet] object NDArrayMacro { result } + // Convert C++ Types to Scala Types private def typeConversion(in : String, argType : String = "") : String = { in match { From 846e88dffa2abfda9230b68c3b4dc575cf54b2ff Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 22 May 2018 09:40:50 -0700 Subject: [PATCH 4/5] Quick fix on redudant code --- .../scala/org/apache/mxnet/NDArrayMacro.scala | 25 ++++++------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index f9a133963bd1..d1ce74c54dbc 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -66,9 +66,9 @@ private[mxnet] object NDArrayMacro { if (!NDArrayfunction.name.startsWith("_") || NDArrayfunction.name.startsWith("_contrib_")) { Seq( // scalastyle:off - // def transpose(kwargs: Map[String, Any] = null)(args: Any*) + // e.g def transpose(kwargs: Map[String, Any] = null)(args: Any*) q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef], - // def transpose(args: Any*) + // e.g def transpose(args: Any*) q"def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef] // scalastyle:on ) @@ -101,21 +101,6 @@ private[mxnet] object NDArrayMacro { // Construct argument field var argDef = ListBuffer[String]() - ndarrayfunction.listOfArgs.foreach(ndarrayarg => { - val currArgName = ndarrayarg.argName match { - case "var" => "vari" - case "type" => "typeOf" - case default => ndarrayarg.argName - } - if (ndarrayarg.isOptional) { - argDef += s"${currArgName} : Option[${ndarrayarg.argType}] = None" - } - else { - argDef += s"${currArgName} : ${ndarrayarg.argType}" - } - }) - argDef += "name : String = null" - argDef += "attr : Map[String, String] = null" // Construct Implementation field var impl = ListBuffer[String]() impl += "val map = scala.collection.mutable.Map[String, Any]()" @@ -127,6 +112,12 @@ private[mxnet] object NDArrayMacro { case "type" => "typeOf" case default => ndarrayarg.argName } + if (ndarrayarg.isOptional) { + argDef += s"${currArgName} : Option[${ndarrayarg.argType}] = None" + } + else { + argDef += s"${currArgName} : ${ndarrayarg.argType}" + } var base = "map(\"" + ndarrayarg.argName + "\") = " + currArgName if (ndarrayarg.isOptional) { base = "if (!" + currArgName + ".isEmpty)" + base + ".get" From f35893a61969993f0f9080ecda360be3f5df1185 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 22 May 2018 17:20:26 -0700 Subject: [PATCH 5/5] Change in Sync --- .../macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index d1ce74c54dbc..bbe786f5a0af 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -182,7 +182,7 @@ private[mxnet] object NDArrayMacro { case "long" | "long(non-negative)" => "Long" case "double" | "doubleorNone" => "Double" case "string" => "String" - case "boolean" => "Boolean" + case "boolean" | "booleanorNone" => "Boolean" case "tupleof" | "tupleof" | "ptr" | "" => "Any" case default => throw new IllegalArgumentException( s"Invalid type for args: $default, $argType")