From 23082ebbaeafafbb6b375c06ac060c22875fdf8f Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 23 Apr 2018 15:36:00 -0700 Subject: [PATCH 01/26] Simplfied current Macros impl to Quasiquote --- .../scala/org/apache/mxnet/SymbolMacro.scala | 64 +++---------------- 1 file changed, 10 insertions(+), 54 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index b6ddaafc7ad7..1a240091ab94 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -52,63 +52,19 @@ private[mxnet] object SymbolImplMacros { else symbolFunctions.filter(!_._1.startsWith("_contrib_")) } - val AST_TYPE_MAP_STRING_ANY = AppliedTypeTree(Ident(TypeName("Map")), - List(Ident(TypeName("String")), Ident(TypeName("Any")))) - val AST_TYPE_MAP_STRING_STRING = AppliedTypeTree(Ident(TypeName("Map")), - List(Ident(TypeName("String")), Ident(TypeName("String")))) - val AST_TYPE_SYMBOL_VARARG = AppliedTypeTree( - Select( - Select(Ident(termNames.ROOTPKG), TermName("scala")), - TypeName("") - ), - List(Select(Select(Select( - Ident(TermName("org")), TermName("apache")), TermName("mxnet")), TypeName("Symbol"))) - ) - - val functionDefs = newSymbolFunctions map { case (funcName, funcProp) => - val functionScope = { - if (isContrib) Modifiers() - else { - if (funcName.startsWith("_")) Modifiers(Flag.PRIVATE) else Modifiers() - } - } - val newName = { - if (isContrib) funcName.substring(funcName.indexOf("_contrib_") + "_contrib_".length()) - else funcName - } + val functionDefs = newSymbolFunctions map { case (funcName, _) => + val tName = TermName(funcName) + q""" + def $tName(name : String = null, attr : Map[String, String] = null) + (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null) + : org.apache.mxnet.Symbol = { + createSymbolGeneral($funcName,name,attr,args,kwargs) + } + """ - // It will generate definition something like, - // def Concat(name: String = null, attr: Map[String, String] = null) - // (args: Symbol*)(kwargs: Map[String, Any] = null) - DefDef(functionScope, TermName(newName), List(), - List( - List( - ValDef(Modifiers(Flag.PARAM | Flag.DEFAULTPARAM), TermName("name"), - Ident(TypeName("String")), Literal(Constant(null))), - ValDef(Modifiers(Flag.PARAM | Flag.DEFAULTPARAM), TermName("attr"), - AST_TYPE_MAP_STRING_STRING, Literal(Constant(null))) - ), - List( - ValDef(Modifiers(), TermName("args"), AST_TYPE_SYMBOL_VARARG, EmptyTree) - ), - List( - ValDef(Modifiers(Flag.PARAM | Flag.DEFAULTPARAM), TermName("kwargs"), - AST_TYPE_MAP_STRING_ANY, Literal(Constant(null))) - ) - ), TypeTree(), - Apply( - Ident(TermName("createSymbolGeneral")), - List( - Literal(Constant(funcName)), - Ident(TermName("name")), - Ident(TermName("attr")), - Ident(TermName("args")), - Ident(TermName("kwargs")) - ) - ) - ) } + val inputs = annottees.map(_.tree).toList // pattern match on the inputs val modDefs = inputs map { From 9b74fc9af926f3d6dff7b866b538c9a3395ac44f Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 23 Apr 2018 17:31:55 -0700 Subject: [PATCH 02/26] Change the Symbol Function Field, add SymbolArg --- .../scala/org/apache/mxnet/SymbolMacro.scala | 73 ++++++++++++++++--- 1 file changed, 64 insertions(+), 9 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 1a240091ab94..36494bd480b8 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -30,7 +30,8 @@ private[mxnet] class AddSymbolFunctions(isContrib: Boolean) extends StaticAnnota } private[mxnet] object SymbolImplMacros { - case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String) + case class SymbolArg(argName: String, argType: String, isOptional : Boolean) + case class SymbolFunction(name: String, listOfArgs: List[SymbolArg]) // scalastyle:off havetype def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { @@ -38,7 +39,7 @@ private[mxnet] object SymbolImplMacros { } // scalastyle:off havetype - private val symbolFunctions: Map[String, SymbolFunction] = initSymbolModule() + private val symbolFunctions: List[SymbolFunction] = initSymbolModule() private def impl(c: blackbox.Context)(addSuper: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = { import c.universe._ @@ -48,11 +49,12 @@ private[mxnet] object SymbolImplMacros { } val newSymbolFunctions = { - if (isContrib) symbolFunctions.filter(_._1.startsWith("_contrib_")) - else symbolFunctions.filter(!_._1.startsWith("_contrib_")) + if (isContrib) symbolFunctions.filter(_.name.startsWith("_contrib_")) + else symbolFunctions.filter(!_.name.startsWith("_contrib_")) } - val functionDefs = newSymbolFunctions map { case (funcName, _) => + val functionDefs = newSymbolFunctions map { symbolfunction => + val funcName = symbolfunction.name val tName = TermName(funcName) q""" def $tName(name : String = null, attr : Map[String, String] = null) @@ -61,7 +63,11 @@ private[mxnet] object SymbolImplMacros { createSymbolGeneral($funcName,name,attr,args,kwargs) } """ + } + val newFunctionDefs = newSymbolFunctions map { symbolfunction => + // TODO: Implement the codeGen + null } @@ -92,20 +98,65 @@ private[mxnet] object SymbolImplMacros { result } + // Convert C++ Types to Scala Types + private def typeConversion(in : String) : String = { + in match { + case "Shape(tuple)" | "ShapeorNone" => "Shape" + case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "Symbol" + case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]" => "Array[Symbol]" + case "float" | "real_t" => "MXFloat" + case "int" | "intorNone" | "int(non-negative)" => "Int" + case "long" | "long(non-negative)" => "Long" + case "double" => "Double" + case "string" => "String" + case "boolean" => "Boolean" + case "tupleof" => "Any" + case default => throw new IllegalArgumentException(s"Invalid type for args: $default") + } + } + + private def argumentCleaner(argType : String) : (String, Boolean) = { + val spaceRemoved = argType.replaceAll("\\s+", "") + var commaRemoved : Array[String] = new Array[String](0) + // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'} + if (spaceRemoved.charAt(0)== '{') { + val endIdx = spaceRemoved.indexOf('}') + commaRemoved = spaceRemoved.substring(endIdx + 1).split(",") + // commaRemoved(0) = spaceRemoved.substring(0, endIdx+1) + commaRemoved(0) = "string" + } else { + commaRemoved = spaceRemoved.split(",") + } + // Optional Field + if (commaRemoved.length == 3) { + (typeConversion(commaRemoved(0)), true) + // TODO: Qing: do we set default value on our side? + // optionalField = " = " + conversion(typeConv, commaRemoved(2).split("=")(1)) + } else if (commaRemoved.length == 2) { + val tempType = typeConversion(argType) + val tempOptional = tempType.equals("Symbol") + (commaRemoved(0), tempOptional) + } else { + throw new IllegalArgumentException(s"Unrecognized arg field: $argType") + } + + } + + // List and add all the atomic symbol functions to current module. - private def initSymbolModule(): Map[String, SymbolFunction] = { + private def initSymbolModule(): List[SymbolFunction] = { val opNames = ListBuffer.empty[String] _LIB.mxListAllOpNames(opNames) opNames.map(opName => { val opHandle = new RefLong _LIB.nnGetOpHandle(opName, opHandle) makeAtomicSymbolFunction(opHandle.value, opName) - }).toMap + }).toList } // Create an atomic symbol function by handle and function name. private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String) - : (String, SymbolFunction) = { + : SymbolFunction = { val name = new RefString val desc = new RefString val keyVarNumArgs = new RefString @@ -130,6 +181,10 @@ private[mxnet] object SymbolImplMacros { println("Symbol function definition:\n" + docStr) } // scalastyle:on println - (aliasName, new SymbolFunction(handle, keyVarNumArgs.value)) + val argList = (argNames zip argTypes) map { case ((argName, argType)) => + val tup = argumentCleaner(argType) + new SymbolArg(argName, tup._1, tup._2) + } + new SymbolFunction(aliasName, argList.toList) } } From 2485146f0c37527272268ff155c333812cfe4258 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 24 Apr 2018 09:30:07 -0700 Subject: [PATCH 03/26] Fix the Macros problem, disable the hidden function _ --- .../scala/org/apache/mxnet/SymbolMacro.scala | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 36494bd480b8..5132fd06cc20 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -99,7 +99,7 @@ private[mxnet] object SymbolImplMacros { } // Convert C++ Types to Scala Types - private def typeConversion(in : String) : String = { + private def typeConversion(in : String, argType : String = "") : String = { in match { case "Shape(tuple)" | "ShapeorNone" => "Shape" case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "Symbol" @@ -111,7 +111,8 @@ private[mxnet] object SymbolImplMacros { case "string" => "String" case "boolean" => "Boolean" case "tupleof" => "Any" - case default => throw new IllegalArgumentException(s"Invalid type for args: $default") + case default => throw new IllegalArgumentException( + s"Invalid type for args: $default, $argType") } } @@ -128,16 +129,17 @@ private[mxnet] object SymbolImplMacros { commaRemoved = spaceRemoved.split(",") } // Optional Field - if (commaRemoved.length == 3) { - (typeConversion(commaRemoved(0)), true) + if (commaRemoved.length >= 3) { + (typeConversion(commaRemoved(0), argType), true) // TODO: Qing: do we set default value on our side? // optionalField = " = " + conversion(typeConv, commaRemoved(2).split("=")(1)) - } else if (commaRemoved.length == 2) { - val tempType = typeConversion(argType) + } else if (commaRemoved.length == 2 || commaRemoved.length == 1) { + val tempType = typeConversion(commaRemoved(0), argType) val tempOptional = tempType.equals("Symbol") - (commaRemoved(0), tempOptional) + (tempType, tempOptional) } else { - throw new IllegalArgumentException(s"Unrecognized arg field: $argType") + throw new IllegalArgumentException( + s"Unrecognized arg field: $argType, ${commaRemoved.length}") } } @@ -147,7 +149,7 @@ private[mxnet] object SymbolImplMacros { private def initSymbolModule(): List[SymbolFunction] = { val opNames = ListBuffer.empty[String] _LIB.mxListAllOpNames(opNames) - opNames.map(opName => { + opNames.filter(!_.startsWith("_")).map(opName => { val opHandle = new RefLong _LIB.nnGetOpHandle(opName, opHandle) makeAtomicSymbolFunction(opHandle.value, opName) @@ -181,7 +183,7 @@ private[mxnet] object SymbolImplMacros { println("Symbol function definition:\n" + docStr) } // scalastyle:on println - val argList = (argNames zip argTypes) map { case ((argName, argType)) => + val argList = argNames zip argTypes map { case (argName, argType) => val tup = argumentCleaner(argType) new SymbolArg(argName, tup._1, tup._2) } From e61e13afea6634ffa4fc293a1ed4f6596ee487a1 Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 25 Apr 2018 13:32:06 -0700 Subject: [PATCH 04/26] Add Implementation for New API --- .../scala/org/apache/mxnet/SymbolMacro.scala | 49 ++++++++++++++++--- 1 file changed, 42 insertions(+), 7 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 5132fd06cc20..8ad86b4a9fb3 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -21,7 +21,6 @@ import scala.annotation.StaticAnnotation import scala.collection.mutable.ListBuffer import scala.language.experimental.macros import scala.reflect.macros.blackbox - import org.apache.mxnet.init.Base._ import org.apache.mxnet.utils.OperatorBuildUtils @@ -53,7 +52,7 @@ private[mxnet] object SymbolImplMacros { else symbolFunctions.filter(!_.name.startsWith("_contrib_")) } - val functionDefs = newSymbolFunctions map { symbolfunction => + var functionDefs = newSymbolFunctions map { symbolfunction => val funcName = symbolfunction.name val tName = TermName(funcName) q""" @@ -65,12 +64,46 @@ private[mxnet] object SymbolImplMacros { """ } - val newFunctionDefs = newSymbolFunctions map { symbolfunction => + + val newFunctionDefs : List[DefDef] = newSymbolFunctions map { symbolfunction => // TODO: Implement the codeGen - null + + var argDef = ListBuffer[String]() + symbolfunction.listOfArgs.foreach(symbolarg => { + val currArgName = if (symbolarg.argName.equals("var")) "vari" else symbolarg.argName + if (symbolarg.isOptional) { + argDef += s"${currArgName} : Option[${symbolarg.argType}] = None" + } + else { + argDef += s"${currArgName} : ${symbolarg.argType}" + } + }) + argDef += "name : String = null" + argDef += "attr : Map[String, String] = null" + + var impl = ListBuffer[String]() + impl += "val map = scala.collection.mutable.Map[String, Any]()" + symbolfunction.listOfArgs.foreach({ symbolarg => + val currArgName = if (symbolarg.argName.equals("var")) "vari" else symbolarg.argName + var base = "map(\"" + symbolarg.argName + "\") = " + currArgName + if (symbolarg.isOptional) { + base = "if (!" + currArgName + ".isEmpty)" + base + ".get" + } + impl += base + }) + impl += "createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, Seq(), map.toMap)" + + val returnType = "Symbol" + var finalStr = s"def ${symbolfunction.name}New" + finalStr += s" (${argDef.mkString(",")}) : Symbol" + finalStr += s" = {${impl.mkString("\n")}}" + c.parse(finalStr).asInstanceOf[DefDef] } + functionDefs = functionDefs ::: newFunctionDefs + + val inputs = annottees.map(_.tree).toList // pattern match on the inputs val modDefs = inputs map { @@ -103,7 +136,8 @@ private[mxnet] object SymbolImplMacros { in match { case "Shape(tuple)" | "ShapeorNone" => "Shape" case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "Symbol" - case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]" => "Array[Symbol]" + case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]" + => "Array[Symbol]" case "float" | "real_t" => "MXFloat" case "int" | "intorNone" | "int(non-negative)" => "Int" case "long" | "long(non-negative)" => "Long" @@ -116,6 +150,7 @@ private[mxnet] object SymbolImplMacros { } } + private def argumentCleaner(argType : String) : (String, Boolean) = { val spaceRemoved = argType.replaceAll("\\s+", "") var commaRemoved : Array[String] = new Array[String](0) @@ -184,8 +219,8 @@ private[mxnet] object SymbolImplMacros { } // scalastyle:on println val argList = argNames zip argTypes map { case (argName, argType) => - val tup = argumentCleaner(argType) - new SymbolArg(argName, tup._1, tup._2) + val typeAndOption = argumentCleaner(argType) + new SymbolArg(argName, typeAndOption._1, typeAndOption._2) } new SymbolFunction(aliasName, argList.toList) } From ff4e24eab9fd55030ce593c8c7667cbfded9c8b8 Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 25 Apr 2018 15:47:30 -0700 Subject: [PATCH 05/26] Trigger CI From 704be92bac26d2b9f1698e4188ac4ba68e71658a Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 26 Apr 2018 10:28:45 -0700 Subject: [PATCH 06/26] Trigger CI From eae7bd625fe1faae492f87857989ce516dba78e7 Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 26 Apr 2018 13:32:52 -0700 Subject: [PATCH 07/26] Add examples and comments --- .../imclassification/TrainMnist.scala | 40 +++++++++---------- .../scala/org/apache/mxnet/SymbolMacro.scala | 9 ++--- 2 files changed, 22 insertions(+), 27 deletions(-) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala index d1ec88d67c6b..699d3f6ad707 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala @@ -30,40 +30,38 @@ object TrainMnist { // multi-layer perceptron def getMlp: Symbol = { val data = Symbol.Variable("data") - val fc1 = Symbol.FullyConnected(name = "fc1")()(Map("data" -> data, "num_hidden" -> 128)) - val act1 = Symbol.Activation(name = "relu1")()(Map("data" -> fc1, "act_type" -> "relu")) - val fc2 = Symbol.FullyConnected(name = "fc2")()(Map("data" -> act1, "num_hidden" -> 64)) - val act2 = Symbol.Activation(name = "relu2")()(Map("data" -> fc2, "act_type" -> "relu")) - val fc3 = Symbol.FullyConnected(name = "fc3")()(Map("data" -> act2, "num_hidden" -> 10)) - val mlp = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc3)) + + val fc1 = Symbol.FullyConnectedNew(data = Some(data), num_hidden = 128, name = "fc1") + val act1 = Symbol.ActivationNew(data = Some(fc1), "relu", name = "relu") + val fc2 = Symbol.FullyConnectedNew(Some(act1), None, None, 64, name = "fc2") + val act2 = Symbol.ActivationNew(data = Some(fc2), "relu", name = "relu2") + val fc3 = Symbol.FullyConnectedNew(Some(act2), None, None, 10, name = "fc3") + val mlp = Symbol.SoftmaxOutputNew(name = "softmax", data = Some(fc3)) mlp } // LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick // Haffner. "Gradient-based learning applied to document recognition." // Proceedings of the IEEE (1998) + def getLenet: Symbol = { val data = Symbol.Variable("data") // first conv - val conv1 = Symbol.Convolution()()( - Map("data" -> data, "kernel" -> "(5, 5)", "num_filter" -> 20)) - val tanh1 = Symbol.Activation()()(Map("data" -> conv1, "act_type" -> "tanh")) - val pool1 = Symbol.Pooling()()(Map("data" -> tanh1, "pool_type" -> "max", - "kernel" -> "(2, 2)", "stride" -> "(2, 2)")) + val conv1 = Symbol.ConvolutionNew(data = Some(data), kernel = Shape(5, 5), num_filter = 20) + val tanh1 = Symbol.tanhNew(data = Some(conv1)) + val pool1 = Symbol.PoolingNew(data = Some(tanh1), pool_type = Some("max"), kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2))) // second conv - val conv2 = Symbol.Convolution()()( - Map("data" -> pool1, "kernel" -> "(5, 5)", "num_filter" -> 50)) - val tanh2 = Symbol.Activation()()(Map("data" -> conv2, "act_type" -> "tanh")) - val pool2 = Symbol.Pooling()()(Map("data" -> tanh2, "pool_type" -> "max", - "kernel" -> "(2, 2)", "stride" -> "(2, 2)")) + val conv2 = Symbol.ConvolutionNew(data = Some(pool1), kernel = Shape(5, 5), num_filter = 50) + val tanh2 = Symbol.tanhNew(data = Some(conv2)) + val pool2 = Symbol.PoolingNew(data = Some(tanh2), pool_type = Some("max"), kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2))) // first fullc - val flatten = Symbol.Flatten()()(Map("data" -> pool2)) - val fc1 = Symbol.FullyConnected()()(Map("data" -> flatten, "num_hidden" -> 500)) - val tanh3 = Symbol.Activation()()(Map("data" -> fc1, "act_type" -> "tanh")) + val flatten = Symbol.FlattenNew(data = Some(pool2)) + val fc1 = Symbol.FullyConnectedNew(data = Some(flatten), num_hidden = 500) + val tanh3 = Symbol.tanhNew(data = Some(fc1)) // second fullc - val fc2 = Symbol.FullyConnected()()(Map("data" -> tanh3, "num_hidden" -> 10)) + val fc2 = Symbol.FullyConnectedNew(data = Some(tanh3), num_hidden = 10) // loss - val lenet = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc2)) + val lenet = Symbol.SoftmaxOutputNew(name = "softmax", data = Some(fc2)) lenet } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 8ad86b4a9fb3..46a09b008eb8 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -66,8 +66,8 @@ private[mxnet] object SymbolImplMacros { val newFunctionDefs : List[DefDef] = newSymbolFunctions map { symbolfunction => - // TODO: Implement the codeGen + // Construct argument field var argDef = ListBuffer[String]() symbolfunction.listOfArgs.foreach(symbolarg => { val currArgName = if (symbolarg.argName.equals("var")) "vari" else symbolarg.argName @@ -80,7 +80,7 @@ private[mxnet] object SymbolImplMacros { }) argDef += "name : String = null" argDef += "attr : Map[String, String] = null" - + // Construct Implementation field var impl = ListBuffer[String]() impl += "val map = scala.collection.mutable.Map[String, Any]()" symbolfunction.listOfArgs.foreach({ symbolarg => @@ -92,7 +92,7 @@ private[mxnet] object SymbolImplMacros { impl += base }) impl += "createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, Seq(), map.toMap)" - + // Combine and build the function string val returnType = "Symbol" var finalStr = s"def ${symbolfunction.name}New" finalStr += s" (${argDef.mkString(",")}) : Symbol" @@ -158,7 +158,6 @@ private[mxnet] object SymbolImplMacros { if (spaceRemoved.charAt(0)== '{') { val endIdx = spaceRemoved.indexOf('}') commaRemoved = spaceRemoved.substring(endIdx + 1).split(",") - // commaRemoved(0) = spaceRemoved.substring(0, endIdx+1) commaRemoved(0) = "string" } else { commaRemoved = spaceRemoved.split(",") @@ -166,8 +165,6 @@ private[mxnet] object SymbolImplMacros { // Optional Field if (commaRemoved.length >= 3) { (typeConversion(commaRemoved(0), argType), true) - // TODO: Qing: do we set default value on our side? - // optionalField = " = " + conversion(typeConv, commaRemoved(2).split("=")(1)) } else if (commaRemoved.length == 2 || commaRemoved.length == 1) { val tempType = typeConversion(commaRemoved(0), argType) val tempOptional = tempType.equals("Symbol") From 8de007d94f6a6a88e032614ce4c5f2d89e83502b Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 26 Apr 2018 14:19:02 -0700 Subject: [PATCH 08/26] Add _contrib_ support --- .../scala/org/apache/mxnet/SymbolMacro.scala | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 46a09b008eb8..583732306822 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -93,9 +93,9 @@ private[mxnet] object SymbolImplMacros { }) impl += "createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, Seq(), map.toMap)" // Combine and build the function string - val returnType = "Symbol" + val returnType = "org.apache.mxnet.Symbol" var finalStr = s"def ${symbolfunction.name}New" - finalStr += s" (${argDef.mkString(",")}) : Symbol" + finalStr += s" (${argDef.mkString(",")}) : $returnType" finalStr += s" = {${impl.mkString("\n")}}" c.parse(finalStr).asInstanceOf[DefDef] } @@ -134,11 +134,11 @@ private[mxnet] object SymbolImplMacros { // Convert C++ Types to Scala Types private def typeConversion(in : String, argType : String = "") : String = { in match { - case "Shape(tuple)" | "ShapeorNone" => "Shape" - case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "Symbol" + case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape" + case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "org.apache.mxnet.Symbol" case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]" - => "Array[Symbol]" - case "float" | "real_t" => "MXFloat" + => "Array[org.apache.mxnet.Symbol]" + case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat" case "int" | "intorNone" | "int(non-negative)" => "Int" case "long" | "long(non-negative)" => "Long" case "double" => "Double" @@ -151,6 +151,7 @@ private[mxnet] object SymbolImplMacros { } + private def argumentCleaner(argType : String) : (String, Boolean) = { val spaceRemoved = argType.replaceAll("\\s+", "") var commaRemoved : Array[String] = new Array[String](0) @@ -167,7 +168,7 @@ private[mxnet] object SymbolImplMacros { (typeConversion(commaRemoved(0), argType), true) } else if (commaRemoved.length == 2 || commaRemoved.length == 1) { val tempType = typeConversion(commaRemoved(0), argType) - val tempOptional = tempType.equals("Symbol") + val tempOptional = tempType.equals("org.apache.mxnet.Symbol") (tempType, tempOptional) } else { throw new IllegalArgumentException( @@ -181,7 +182,7 @@ private[mxnet] object SymbolImplMacros { private def initSymbolModule(): List[SymbolFunction] = { val opNames = ListBuffer.empty[String] _LIB.mxListAllOpNames(opNames) - opNames.filter(!_.startsWith("_")).map(opName => { + opNames.filter(op => !op.startsWith("_") || op.startsWith("_contrib_")).map(opName => { val opHandle = new RefLong _LIB.nnGetOpHandle(opName, opHandle) makeAtomicSymbolFunction(opHandle.value, opName) From 51581e35f9a0189373b83eb42434eae251a288b8 Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 26 Apr 2018 14:25:03 -0700 Subject: [PATCH 09/26] Resolve Style issues --- .../mxnetexamples/imclassification/TrainMnist.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala index 699d3f6ad707..8bb9df89d16c 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala @@ -31,7 +31,7 @@ object TrainMnist { def getMlp: Symbol = { val data = Symbol.Variable("data") - val fc1 = Symbol.FullyConnectedNew(data = Some(data), num_hidden = 128, name = "fc1") + val fc1 = Symbol.FullyConnectedNew(data = Some(data), num_hidden = 128, name = "fc1") val act1 = Symbol.ActivationNew(data = Some(fc1), "relu", name = "relu") val fc2 = Symbol.FullyConnectedNew(Some(act1), None, None, 64, name = "fc2") val act2 = Symbol.ActivationNew(data = Some(fc2), "relu", name = "relu2") @@ -49,11 +49,13 @@ object TrainMnist { // first conv val conv1 = Symbol.ConvolutionNew(data = Some(data), kernel = Shape(5, 5), num_filter = 20) val tanh1 = Symbol.tanhNew(data = Some(conv1)) - val pool1 = Symbol.PoolingNew(data = Some(tanh1), pool_type = Some("max"), kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2))) + val pool1 = Symbol.PoolingNew(data = Some(tanh1), pool_type = Some("max"), + kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2))) // second conv val conv2 = Symbol.ConvolutionNew(data = Some(pool1), kernel = Shape(5, 5), num_filter = 50) val tanh2 = Symbol.tanhNew(data = Some(conv2)) - val pool2 = Symbol.PoolingNew(data = Some(tanh2), pool_type = Some("max"), kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2))) + val pool2 = Symbol.PoolingNew(data = Some(tanh2), pool_type = Some("max"), + kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2))) // first fullc val flatten = Symbol.FlattenNew(data = Some(pool2)) val fc1 = Symbol.FullyConnectedNew(data = Some(flatten), num_hidden = 500) From e370c70368446e827107afb4fa804ffa910ecdb7 Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 30 Apr 2018 10:57:04 -0700 Subject: [PATCH 10/26] Add Depreciated to the old APIs --- .../macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 583732306822..eb83a615aaab 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -56,6 +56,7 @@ private[mxnet] object SymbolImplMacros { val funcName = symbolfunction.name val tName = TermName(funcName) q""" + @Deprecated def $tName(name : String = null, attr : Map[String, String] = null) (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null) : org.apache.mxnet.Symbol = { From 6d7d4bc7abb2ac9c1aeea75cafddcbe2aa8ec32a Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 1 May 2018 13:04:33 -0700 Subject: [PATCH 11/26] New namespace for Symbol API --- .../scala/org/apache/mxnet/NewSymbol.scala | 20 ++++ .../main/scala/org/apache/mxnet/Symbol.scala | 2 + .../imclassification/TrainMnist.scala | 34 +++---- .../scala/org/apache/mxnet/SymbolMacro.scala | 99 ++++++++++--------- 4 files changed, 94 insertions(+), 61 deletions(-) create mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/NewSymbol.scala diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NewSymbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NewSymbol.scala new file mode 100644 index 000000000000..56138a7d0325 --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NewSymbol.scala @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mxnet +@AddNewSymbolFunctions(false) +object NewSymbol { +} diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index 13f85a731dc4..50eb617817f3 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -830,6 +830,8 @@ object Symbol { private val functions: Map[String, SymbolFunction] = initSymbolModule() private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3) + val api = NewSymbol + def pow(sym1: Symbol, sym2: Symbol): Symbol = { Symbol.createFromListedSymbols("_Power")(Array(sym1, sym2)) } diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala index 8bb9df89d16c..e9171bd47c28 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala @@ -31,12 +31,12 @@ object TrainMnist { def getMlp: Symbol = { val data = Symbol.Variable("data") - val fc1 = Symbol.FullyConnectedNew(data = Some(data), num_hidden = 128, name = "fc1") - val act1 = Symbol.ActivationNew(data = Some(fc1), "relu", name = "relu") - val fc2 = Symbol.FullyConnectedNew(Some(act1), None, None, 64, name = "fc2") - val act2 = Symbol.ActivationNew(data = Some(fc2), "relu", name = "relu2") - val fc3 = Symbol.FullyConnectedNew(Some(act2), None, None, 10, name = "fc3") - val mlp = Symbol.SoftmaxOutputNew(name = "softmax", data = Some(fc3)) + val fc1 = Symbol.api.FullyConnected(data = Some(data), num_hidden = 128, name = "fc1") + val act1 = Symbol.api.Activation (data = Some(fc1), "relu", name = "relu") + val fc2 = Symbol.api.FullyConnected(Some(act1), None, None, 64, name = "fc2") + val act2 = Symbol.api.Activation(data = Some(fc2), "relu", name = "relu2") + val fc3 = Symbol.api.FullyConnected(Some(act2), None, None, 10, name = "fc3") + val mlp = Symbol.api.SoftmaxOutput(name = "softmax", data = Some(fc3)) mlp } @@ -47,23 +47,23 @@ object TrainMnist { def getLenet: Symbol = { val data = Symbol.Variable("data") // first conv - val conv1 = Symbol.ConvolutionNew(data = Some(data), kernel = Shape(5, 5), num_filter = 20) - val tanh1 = Symbol.tanhNew(data = Some(conv1)) - val pool1 = Symbol.PoolingNew(data = Some(tanh1), pool_type = Some("max"), + val conv1 = Symbol.api.Convolution(data = Some(data), kernel = Shape(5, 5), num_filter = 20) + val tanh1 = Symbol.api.tanh(data = Some(conv1)) + val pool1 = Symbol.api.Pooling(data = Some(tanh1), pool_type = Some("max"), kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2))) // second conv - val conv2 = Symbol.ConvolutionNew(data = Some(pool1), kernel = Shape(5, 5), num_filter = 50) - val tanh2 = Symbol.tanhNew(data = Some(conv2)) - val pool2 = Symbol.PoolingNew(data = Some(tanh2), pool_type = Some("max"), + val conv2 = Symbol.api.Convolution(data = Some(pool1), kernel = Shape(5, 5), num_filter = 50) + val tanh2 = Symbol.api.tanh(data = Some(conv2)) + val pool2 = Symbol.api.Pooling(data = Some(tanh2), pool_type = Some("max"), kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2))) // first fullc - val flatten = Symbol.FlattenNew(data = Some(pool2)) - val fc1 = Symbol.FullyConnectedNew(data = Some(flatten), num_hidden = 500) - val tanh3 = Symbol.tanhNew(data = Some(fc1)) + val flatten = Symbol.api.Flatten(data = Some(pool2)) + val fc1 = Symbol.api.FullyConnected(data = Some(flatten), num_hidden = 500) + val tanh3 = Symbol.api.tanh(data = Some(fc1)) // second fullc - val fc2 = Symbol.FullyConnectedNew(data = Some(tanh3), num_hidden = 10) + val fc2 = Symbol.api.FullyConnected(data = Some(tanh3), num_hidden = 10) // loss - val lenet = Symbol.SoftmaxOutputNew(name = "softmax", data = Some(fc2)) + val lenet = Symbol.api.SoftmaxOutput(name = "softmax", data = Some(fc2)) lenet } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index eb83a615aaab..c74a63c4389f 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -28,23 +28,32 @@ private[mxnet] class AddSymbolFunctions(isContrib: Boolean) extends StaticAnnota private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.addDefs } +private[mxnet] class AddNewSymbolFunctions(isContrib: Boolean) extends StaticAnnotation { + private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.addNewDefs +} + private[mxnet] object SymbolImplMacros { case class SymbolArg(argName: String, argType: String, isOptional : Boolean) case class SymbolFunction(name: String, listOfArgs: List[SymbolArg]) // scalastyle:off havetype def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { - impl(c)(false, annottees: _*) + impl(c)(false, false, annottees: _*) } - // scalastyle:off havetype + def addNewDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { + impl(c)(false, true, annottees: _*) + } + // scalastyle:on havetype private val symbolFunctions: List[SymbolFunction] = initSymbolModule() - private def impl(c: blackbox.Context)(addSuper: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = { + private def impl(c: blackbox.Context)(addSuper: Boolean, + newAPI: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = { import c.universe._ val isContrib: Boolean = c.prefix.tree match { case q"new AddSymbolFunctions($b)" => c.eval[Boolean](c.Expr(b)) + case q"new AddNewSymbolFunctions($b)" => c.eval[Boolean](c.Expr(b)) } val newSymbolFunctions = { @@ -52,58 +61,60 @@ private[mxnet] object SymbolImplMacros { else symbolFunctions.filter(!_.name.startsWith("_contrib_")) } - var functionDefs = newSymbolFunctions map { symbolfunction => + var functionDefs = List[DefDef]() + + if (!newAPI) { + functionDefs = newSymbolFunctions map { symbolfunction => val funcName = symbolfunction.name val tName = TermName(funcName) q""" - @Deprecated def $tName(name : String = null, attr : Map[String, String] = null) (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null) : org.apache.mxnet.Symbol = { createSymbolGeneral($funcName,name,attr,args,kwargs) } - """ - } - - - val newFunctionDefs : List[DefDef] = newSymbolFunctions map { symbolfunction => - - // Construct argument field - var argDef = ListBuffer[String]() - symbolfunction.listOfArgs.foreach(symbolarg => { - val currArgName = if (symbolarg.argName.equals("var")) "vari" else symbolarg.argName - if (symbolarg.isOptional) { - argDef += s"${currArgName} : Option[${symbolarg.argType}] = None" - } - else { - argDef += s"${currArgName} : ${symbolarg.argType}" - } - }) - argDef += "name : String = null" - argDef += "attr : Map[String, String] = null" - // Construct Implementation field - var impl = ListBuffer[String]() - impl += "val map = scala.collection.mutable.Map[String, Any]()" - symbolfunction.listOfArgs.foreach({ symbolarg => - val currArgName = if (symbolarg.argName.equals("var")) "vari" else symbolarg.argName - var base = "map(\"" + symbolarg.argName + "\") = " + currArgName - if (symbolarg.isOptional) { - base = "if (!" + currArgName + ".isEmpty)" + base + ".get" - } - impl += base - }) - impl += "createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, Seq(), map.toMap)" - // Combine and build the function string - val returnType = "org.apache.mxnet.Symbol" - var finalStr = s"def ${symbolfunction.name}New" - finalStr += s" (${argDef.mkString(",")}) : $returnType" - finalStr += s" = {${impl.mkString("\n")}}" - c.parse(finalStr).asInstanceOf[DefDef] + """.asInstanceOf[DefDef] + } + } else { + functionDefs = newSymbolFunctions map { symbolfunction => + + // Construct argument field + var argDef = ListBuffer[String]() + symbolfunction.listOfArgs.foreach(symbolarg => { + val currArgName = if (symbolarg.argName.equals("var")) "vari" else symbolarg.argName + if (symbolarg.isOptional) { + argDef += s"${currArgName} : Option[${symbolarg.argType}] = None" + } + else { + argDef += s"${currArgName} : ${symbolarg.argType}" + } + }) + argDef += "name : String = null" + argDef += "attr : Map[String, String] = null" + // Construct Implementation field + var impl = ListBuffer[String]() + impl += "val map = scala.collection.mutable.Map[String, Any]()" + symbolfunction.listOfArgs.foreach({ symbolarg => + val currArgName = if (symbolarg.argName.equals("var")) "vari" else symbolarg.argName + var base = "map(\"" + symbolarg.argName + "\") = " + currArgName + if (symbolarg.isOptional) { + base = "if (!" + currArgName + ".isEmpty)" + base + ".get" + } + impl += base + }) + // scalastyle:off + impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, Seq(), map.toMap)" + // scalastyle:on + // Combine and build the function string + val returnType = "org.apache.mxnet.Symbol" + var finalStr = s"def ${symbolfunction.name}" + finalStr += s" (${argDef.mkString(",")}) : $returnType" + finalStr += s" = {${impl.mkString("\n")}}" + c.parse(finalStr).asInstanceOf[DefDef] + } } - functionDefs = functionDefs ::: newFunctionDefs - val inputs = annottees.map(_.tree).toList // pattern match on the inputs From c1cbf16fa84057439480c9468473627e4ec4ae1a Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 1 May 2018 13:24:15 -0700 Subject: [PATCH 12/26] Add require tracker to check the input format --- .../macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index c74a63c4389f..bcba8c58de8e 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -177,6 +177,9 @@ private[mxnet] object SymbolImplMacros { } // Optional Field if (commaRemoved.length >= 3) { + // arg: Type, optional, default = Null + require(commaRemoved(1).equals("optional")) + require(commaRemoved(2).startsWith("default=")) (typeConversion(commaRemoved(0), argType), true) } else if (commaRemoved.length == 2 || commaRemoved.length == 1) { val tempType = typeConversion(commaRemoved(0), argType) From 0d34b942b2966c7b3bd97d1def5fd7e4de27b5d3 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 1 May 2018 15:01:27 -0700 Subject: [PATCH 13/26] Trigger CI From 4268c7e1b24bed8734f10d9ac7fc7dd2e52d3784 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 1 May 2018 16:35:48 -0700 Subject: [PATCH 14/26] Trigger CI From b6fe9420703e06648113de739ee61f0491dfe79f Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 2 May 2018 14:08:15 -0700 Subject: [PATCH 15/26] Change names and add comments --- .../main/scala/org/apache/mxnet/Symbol.scala | 2 +- .../mxnet/{NewSymbol.scala => SymbolAPI.scala} | 4 ++-- .../scala/org/apache/mxnet/SymbolMacro.scala | 18 +++++++++++++++--- 3 files changed, 18 insertions(+), 6 deletions(-) rename scala-package/core/src/main/scala/org/apache/mxnet/{NewSymbol.scala => SymbolAPI.scala} (94%) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index 50eb617817f3..60efd2ba62bd 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -830,7 +830,7 @@ object Symbol { private val functions: Map[String, SymbolFunction] = initSymbolModule() private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3) - val api = NewSymbol + val api = SymbolAPI def pow(sym1: Symbol, sym2: Symbol): Symbol = { Symbol.createFromListedSymbols("_Power")(Array(sym1, sym2)) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NewSymbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala similarity index 94% rename from scala-package/core/src/main/scala/org/apache/mxnet/NewSymbol.scala rename to scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala index 56138a7d0325..d5dd401a82e9 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NewSymbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala @@ -15,6 +15,6 @@ * limitations under the License. */ package org.apache.mxnet -@AddNewSymbolFunctions(false) -object NewSymbol { +@AddSymbolAPIs(false) +object SymbolAPI { } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index bcba8c58de8e..6c6d84d291ab 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -28,7 +28,7 @@ private[mxnet] class AddSymbolFunctions(isContrib: Boolean) extends StaticAnnota private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.addDefs } -private[mxnet] class AddNewSymbolFunctions(isContrib: Boolean) extends StaticAnnotation { +private[mxnet] class AddSymbolAPIs(isContrib: Boolean) extends StaticAnnotation { private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.addNewDefs } @@ -53,7 +53,7 @@ private[mxnet] object SymbolImplMacros { val isContrib: Boolean = c.prefix.tree match { case q"new AddSymbolFunctions($b)" => c.eval[Boolean](c.Expr(b)) - case q"new AddNewSymbolFunctions($b)" => c.eval[Boolean](c.Expr(b)) + case q"new AddSymbolAPIs($b)" => c.eval[Boolean](c.Expr(b)) } val newSymbolFunctions = { @@ -95,6 +95,8 @@ private[mxnet] object SymbolImplMacros { var impl = ListBuffer[String]() impl += "val map = scala.collection.mutable.Map[String, Any]()" symbolfunction.listOfArgs.foreach({ symbolarg => + // var is a special word used to define variable in Scala, + // need to changed to something else in order to make it work val currArgName = if (symbolarg.argName.equals("var")) "vari" else symbolarg.argName var base = "map(\"" + symbolarg.argName + "\") = " + currArgName if (symbolarg.isOptional) { @@ -163,7 +165,16 @@ private[mxnet] object SymbolImplMacros { } - + /** + * By default, the argType come from the C++ API is a description more than a single word + * For Example: + * , , + * The three field shown above do not usually come at the same time + * This function used the above format to determine if the argument is + * optional, what is it Scala type and possibly pass in a default value + * @param argType Raw arguement Type description + * @return (Scala_Type, isOptional) + */ private def argumentCleaner(argType : String) : (String, Boolean) = { val spaceRemoved = argType.replaceAll("\\s+", "") var commaRemoved : Array[String] = new Array[String](0) @@ -197,6 +208,7 @@ private[mxnet] object SymbolImplMacros { private def initSymbolModule(): List[SymbolFunction] = { val opNames = ListBuffer.empty[String] _LIB.mxListAllOpNames(opNames) + // TODO: Add '_linalg_', '_sparse_', '_image_' support opNames.filter(op => !op.startsWith("_") || op.startsWith("_contrib_")).map(opName => { val opHandle = new RefLong _LIB.nnGetOpHandle(opName, opHandle) From f3b1d963884030906b41cb5f096254d18d528da2 Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 2 May 2018 16:15:21 -0700 Subject: [PATCH 16/26] Add underscore support --- .../scala/org/apache/mxnet/SymbolMacro.scala | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 6c6d84d291ab..0bd8db5e23c4 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -81,7 +81,11 @@ private[mxnet] object SymbolImplMacros { // Construct argument field var argDef = ListBuffer[String]() symbolfunction.listOfArgs.foreach(symbolarg => { - val currArgName = if (symbolarg.argName.equals("var")) "vari" else symbolarg.argName + val currArgName = symbolarg.argName match { + case "var" => "vari" + case "type" => "typeOf" + case default => symbolarg.argName + } if (symbolarg.isOptional) { argDef += s"${currArgName} : Option[${symbolarg.argType}] = None" } @@ -97,7 +101,11 @@ private[mxnet] object SymbolImplMacros { symbolfunction.listOfArgs.foreach({ symbolarg => // var is a special word used to define variable in Scala, // need to changed to something else in order to make it work - val currArgName = if (symbolarg.argName.equals("var")) "vari" else symbolarg.argName + val currArgName = symbolarg.argName match { + case "var" => "vari" + case "type" => "typeOf" + case default => symbolarg.argName + } var base = "map(\"" + symbolarg.argName + "\") = " + currArgName if (symbolarg.isOptional) { base = "if (!" + currArgName + ".isEmpty)" + base + ".get" @@ -155,10 +163,10 @@ private[mxnet] object SymbolImplMacros { case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat" case "int" | "intorNone" | "int(non-negative)" => "Int" case "long" | "long(non-negative)" => "Long" - case "double" => "Double" + case "double" | "doubleorNone" => "Double" case "string" => "String" case "boolean" => "Boolean" - case "tupleof" => "Any" + case "tupleof" | "tupleof" | "ptr" | "" => "Any" case default => throw new IllegalArgumentException( s"Invalid type for args: $default, $argType") } @@ -209,7 +217,7 @@ private[mxnet] object SymbolImplMacros { val opNames = ListBuffer.empty[String] _LIB.mxListAllOpNames(opNames) // TODO: Add '_linalg_', '_sparse_', '_image_' support - opNames.filter(op => !op.startsWith("_") || op.startsWith("_contrib_")).map(opName => { + opNames.map(opName => { val opHandle = new RefLong _LIB.nnGetOpHandle(opName, opHandle) makeAtomicSymbolFunction(opHandle.value, opName) From 5c05c8d3fa285abc6d3ed6f284d1e344702e2fce Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 2 May 2018 16:32:28 -0700 Subject: [PATCH 17/26] Disable underscore function from generation --- .../macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 0bd8db5e23c4..67b84960d9a9 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -57,8 +57,9 @@ private[mxnet] object SymbolImplMacros { } val newSymbolFunctions = { - if (isContrib) symbolFunctions.filter(_.name.startsWith("_contrib_")) - else symbolFunctions.filter(!_.name.startsWith("_contrib_")) + if (isContrib) symbolFunctions.filter( + func => func.name.startsWith("_contrib_") || !func.name.startsWith("_")) + else symbolFunctions.filter(!_.name.startsWith("_")) } var functionDefs = List[DefDef]() From 006d3b50e28e5e7ac757496b9b3b903cf341f9a4 Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 3 May 2018 13:14:16 -0700 Subject: [PATCH 18/26] Trigger CI From 0dd3ee17c68ff8c38845d913ece5b7ee3e8ba1c4 Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 7 May 2018 11:22:17 -0700 Subject: [PATCH 19/26] use different impl method for new API --- .../scala/org/apache/mxnet/SymbolMacro.scala | 140 +++++++++++------- 1 file changed, 84 insertions(+), 56 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 67b84960d9a9..5c77fc1c9411 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -38,22 +38,23 @@ private[mxnet] object SymbolImplMacros { // scalastyle:off havetype def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { - impl(c)(false, false, annottees: _*) + impl(c)(annottees: _*) } def addNewDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { - impl(c)(false, true, annottees: _*) + newAPIImpl(c)(annottees: _*) } // scalastyle:on havetype private val symbolFunctions: List[SymbolFunction] = initSymbolModule() - private def impl(c: blackbox.Context)(addSuper: Boolean, - newAPI: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = { + /** + * Implementation for fixed input API structure + */ + private def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { import c.universe._ val isContrib: Boolean = c.prefix.tree match { case q"new AddSymbolFunctions($b)" => c.eval[Boolean](c.Expr(b)) - case q"new AddSymbolAPIs($b)" => c.eval[Boolean](c.Expr(b)) } val newSymbolFunctions = { @@ -62,10 +63,8 @@ private[mxnet] object SymbolImplMacros { else symbolFunctions.filter(!_.name.startsWith("_")) } - var functionDefs = List[DefDef]() - if (!newAPI) { - functionDefs = newSymbolFunctions map { symbolfunction => + val functionDefs = newSymbolFunctions map { symbolfunction => val funcName = symbolfunction.name val tName = TermName(funcName) q""" @@ -76,64 +75,93 @@ private[mxnet] object SymbolImplMacros { } """.asInstanceOf[DefDef] } - } else { - functionDefs = newSymbolFunctions map { symbolfunction => - // Construct argument field - var argDef = ListBuffer[String]() - symbolfunction.listOfArgs.foreach(symbolarg => { - val currArgName = symbolarg.argName match { - case "var" => "vari" - case "type" => "typeOf" - case default => symbolarg.argName - } - if (symbolarg.isOptional) { - argDef += s"${currArgName} : Option[${symbolarg.argType}] = None" - } - else { - argDef += s"${currArgName} : ${symbolarg.argType}" - } - }) - argDef += "name : String = null" - argDef += "attr : Map[String, String] = null" - // Construct Implementation field - var impl = ListBuffer[String]() - impl += "val map = scala.collection.mutable.Map[String, Any]()" - symbolfunction.listOfArgs.foreach({ symbolarg => - // var is a special word used to define variable in Scala, - // need to changed to something else in order to make it work - val currArgName = symbolarg.argName match { - case "var" => "vari" - case "type" => "typeOf" - case default => symbolarg.argName - } - var base = "map(\"" + symbolarg.argName + "\") = " + currArgName - if (symbolarg.isOptional) { - base = "if (!" + currArgName + ".isEmpty)" + base + ".get" - } - impl += base - }) - // scalastyle:off - impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, Seq(), map.toMap)" - // scalastyle:on - // Combine and build the function string - val returnType = "org.apache.mxnet.Symbol" - var finalStr = s"def ${symbolfunction.name}" - finalStr += s" (${argDef.mkString(",")}) : $returnType" - finalStr += s" = {${impl.mkString("\n")}}" - c.parse(finalStr).asInstanceOf[DefDef] - } + structGeneration(c)(functionDefs, annottees : _*) + } + + /** + * Implementation for Dynamic typed API Symbol.api. + */ + private def newAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = { + import c.universe._ + + val isContrib: Boolean = c.prefix.tree match { + case q"new AddSymbolAPIs($b)" => c.eval[Boolean](c.Expr(b)) } + val newSymbolFunctions = { + if (isContrib) symbolFunctions.filter( + func => func.name.startsWith("_contrib_") || !func.name.startsWith("_")) + else symbolFunctions.filter(!_.name.startsWith("_")) + } + + val functionDefs = newSymbolFunctions map { symbolfunction => + // Construct argument field + var argDef = ListBuffer[String]() + symbolfunction.listOfArgs.foreach(symbolarg => { + val currArgName = symbolarg.argName match { + case "var" => "vari" + case "type" => "typeOf" + case default => symbolarg.argName + } + if (symbolarg.isOptional) { + argDef += s"${currArgName} : Option[${symbolarg.argType}] = None" + } + else { + argDef += s"${currArgName} : ${symbolarg.argType}" + } + }) + argDef += "name : String = null" + argDef += "attr : Map[String, String] = null" + // Construct Implementation field + var impl = ListBuffer[String]() + impl += "val map = scala.collection.mutable.Map[String, Any]()" + symbolfunction.listOfArgs.foreach({ symbolarg => + // var is a special word used to define variable in Scala, + // need to changed to something else in order to make it work + val currArgName = symbolarg.argName match { + case "var" => "vari" + case "type" => "typeOf" + case default => symbolarg.argName + } + var base = "map(\"" + symbolarg.argName + "\") = " + currArgName + if (symbolarg.isOptional) { + base = "if (!" + currArgName + ".isEmpty)" + base + ".get" + } + impl += base + }) + // scalastyle:off + impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, Seq(), map.toMap)" + // scalastyle:on + // Combine and build the function string + val returnType = "org.apache.mxnet.Symbol" + var finalStr = s"def ${symbolfunction.name}" + finalStr += s" (${argDef.mkString(",")}) : $returnType" + finalStr += s" = {${impl.mkString("\n")}}" + c.parse(finalStr).asInstanceOf[DefDef] + } + structGeneration(c)(functionDefs, annottees : _*) + } + /** + * Generate class structure for all function APIs + * @param c + * @param funcDef DefDef type of function definitions + * @param annottees + * @return + */ + private def structGeneration(c: blackbox.Context) + (funcDef : List[c.universe.DefDef], annottees: c.Expr[Any]*) + : c.Expr[Any] = { + import c.universe._ val inputs = annottees.map(_.tree).toList // pattern match on the inputs val modDefs = inputs map { case ClassDef(mods, name, something, template) => val q = template match { case Template(superMaybe, emptyValDef, defs) => - Template(superMaybe, emptyValDef, defs ++ functionDefs) + Template(superMaybe, emptyValDef, defs ++ funcDef) case ex => throw new IllegalArgumentException(s"Invalid template: $ex") } @@ -141,7 +169,7 @@ private[mxnet] object SymbolImplMacros { case ModuleDef(mods, name, template) => val q = template match { case Template(superMaybe, emptyValDef, defs) => - Template(superMaybe, emptyValDef, defs ++ functionDefs) + Template(superMaybe, emptyValDef, defs ++ funcDef) case ex => throw new IllegalArgumentException(s"Invalid template: $ex") } From 80753d648347be3df7a555b09eca8867f88fe441 Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 7 May 2018 15:46:12 -0700 Subject: [PATCH 20/26] add unit test to Macros --- .../scala/org/apache/mxnet/init/Base.scala | 6 +- scala-package/macros/pom.xml | 38 +++++++++++++ .../scala/org/apache/mxnet/SymbolMacro.scala | 4 +- .../src/test/resources/log4j.properties | 24 ++++++++ .../scala/org/apache/mxnet/MacrosSuite.scala | 57 +++++++++++++++++++ 5 files changed, 126 insertions(+), 3 deletions(-) create mode 100644 scala-package/macros/src/test/resources/log4j.properties create mode 100644 scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala diff --git a/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala b/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala index 7af2e052255c..400ecce7fea4 100644 --- a/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala +++ b/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala @@ -37,7 +37,11 @@ object Base { @throws(classOf[UnsatisfiedLinkError]) private def tryLoadInitLibrary(): Unit = { - val baseDir = System.getProperty("user.dir") + "/init-native" + // val baseDir = System.getProperty("user.dir") + "/init-native" + var baseDir = System.getProperty("user.dir") + "/init-native" + if (System.getenv().containsKey("BASEDIR")) { + baseDir = sys.env("BASEDIR") + } val os = System.getProperty("os.name") // ref: http://lopica.sourceforge.net/os.html if (os.startsWith("Linux")) { diff --git a/scala-package/macros/pom.xml b/scala-package/macros/pom.xml index 0aa3030e7ce3..fbc8c4061a6a 100644 --- a/scala-package/macros/pom.xml +++ b/scala-package/macros/pom.xml @@ -52,4 +52,42 @@ ${libtype} + + + + + org.apache.maven.plugins + maven-jar-plugin + + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + org.scalatest + scalatest-maven-plugin + + + ${project.parent.basedir}/init-native + + + -Djava.library.path=${project.parent.basedir}/native/${platform}/target \ + -Dlog4j.configuration=file://${project.basedir}/src/test/resources/log4j.properties + + + + + org.scalastyle + scalastyle-maven-plugin + + + + diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 5c77fc1c9411..35e9a2085d77 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -183,7 +183,7 @@ private[mxnet] object SymbolImplMacros { } // Convert C++ Types to Scala Types - private def typeConversion(in : String, argType : String = "") : String = { + def typeConversion(in : String, argType : String = "") : String = { in match { case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape" case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "org.apache.mxnet.Symbol" @@ -212,7 +212,7 @@ private[mxnet] object SymbolImplMacros { * @param argType Raw arguement Type description * @return (Scala_Type, isOptional) */ - private def argumentCleaner(argType : String) : (String, Boolean) = { + def argumentCleaner(argType : String) : (String, Boolean) = { val spaceRemoved = argType.replaceAll("\\s+", "") var commaRemoved : Array[String] = new Array[String](0) // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'} diff --git a/scala-package/macros/src/test/resources/log4j.properties b/scala-package/macros/src/test/resources/log4j.properties new file mode 100644 index 000000000000..d82fd7ea4f3d --- /dev/null +++ b/scala-package/macros/src/test/resources/log4j.properties @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# for development debugging +log4j.rootLogger = debug, stdout + +log4j.appender.stdout = org.apache.log4j.ConsoleAppender +log4j.appender.stdout.Target = System.out +log4j.appender.stdout.layout = org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss,SSS} [%t] [%c] [%p] - %m%n diff --git a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala new file mode 100644 index 000000000000..86e92f15f605 --- /dev/null +++ b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.slf4j.LoggerFactory + +class MacrosSuite extends FunSuite with BeforeAndAfterAll { + + private val logger = LoggerFactory.getLogger(classOf[MacrosSuite]) + + + override def beforeAll() { + } + + override def afterAll(): Unit = { + + } + + test("MacrosSuite-testArgumentCleaner") { + val input = List( + "Symbol, optional, default = Null", + "int, required", + "Shape(tuple), optional, default = []", + "{'csr', 'default', 'row_sparse'}, optional, default = 'csr'", + ", required" + ) + val output = List( + ("org.apache.mxnet.Symbol", true), + ("Int", false), + ("org.apache.mxnet.Shape", true), + ("String", true), + ("Any", false) + ) + + for (idx <- input.indices) { + val result = SymbolImplMacros.argumentCleaner(input(idx)) + assert(result._1 === output(idx)._1 && result._2 === output(idx)._2) + } + } + +} From 57984a29d04b089cadcaf9e4e5e7a9110183bd7b Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 8 May 2018 10:25:25 -0700 Subject: [PATCH 21/26] retrigger CI From e83648047e7e0bccde0dc6e6cab2f9110da01098 Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 9 May 2018 15:20:46 -0700 Subject: [PATCH 22/26] reTrigger CI From 788dddeb1eed99f4acc1be7aef53a574ce0c5a47 Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 10 May 2018 15:46:34 -0700 Subject: [PATCH 23/26] Name and minor fixes --- .../scala/org/apache/mxnet/SymbolAPI.scala | 6 +++++ .../scala/org/apache/mxnet/init/Base.scala | 5 ++-- scala-package/macros/pom.xml | 2 +- .../scala/org/apache/mxnet/SymbolMacro.scala | 27 +++++++------------ .../scala/org/apache/mxnet/MacrosSuite.scala | 7 ----- 5 files changed, 19 insertions(+), 28 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala index d5dd401a82e9..49de9ae73218 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala @@ -15,6 +15,12 @@ * limitations under the License. */ package org.apache.mxnet + + @AddSymbolAPIs(false) +/** + * typesafe Symbol API: Symbol.api._ + * Main code will be generated during compile time through Macros + */ object SymbolAPI { } diff --git a/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala b/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala index 400ecce7fea4..80b4bd0daaa8 100644 --- a/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala +++ b/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala @@ -37,10 +37,9 @@ object Base { @throws(classOf[UnsatisfiedLinkError]) private def tryLoadInitLibrary(): Unit = { - // val baseDir = System.getProperty("user.dir") + "/init-native" var baseDir = System.getProperty("user.dir") + "/init-native" - if (System.getenv().containsKey("BASEDIR")) { - baseDir = sys.env("BASEDIR") + if (System.getenv().containsKey("MXNET_SCALA_MACRO_BASEDIR")) { + baseDir = sys.env("MXNET_SCALA_MACRO_BASEDIR") } val os = System.getProperty("os.name") // ref: http://lopica.sourceforge.net/os.html diff --git a/scala-package/macros/pom.xml b/scala-package/macros/pom.xml index fbc8c4061a6a..8eea759c18a3 100644 --- a/scala-package/macros/pom.xml +++ b/scala-package/macros/pom.xml @@ -75,7 +75,7 @@ scalatest-maven-plugin - ${project.parent.basedir}/init-native + ${project.parent.basedir}/init-native -Djava.library.path=${project.parent.basedir}/native/${platform}/target \ diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 35e9a2085d77..cc8d5444565b 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -29,7 +29,7 @@ private[mxnet] class AddSymbolFunctions(isContrib: Boolean) extends StaticAnnota } private[mxnet] class AddSymbolAPIs(isContrib: Boolean) extends StaticAnnotation { - private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.addNewDefs + private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.typeSafeAPIDefs } private[mxnet] object SymbolImplMacros { @@ -40,7 +40,7 @@ private[mxnet] object SymbolImplMacros { def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { impl(c)(annottees: _*) } - def addNewDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { + def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { newAPIImpl(c)(annottees: _*) } // scalastyle:on havetype @@ -99,21 +99,6 @@ private[mxnet] object SymbolImplMacros { // Construct argument field var argDef = ListBuffer[String]() - symbolfunction.listOfArgs.foreach(symbolarg => { - val currArgName = symbolarg.argName match { - case "var" => "vari" - case "type" => "typeOf" - case default => symbolarg.argName - } - if (symbolarg.isOptional) { - argDef += s"${currArgName} : Option[${symbolarg.argType}] = None" - } - else { - argDef += s"${currArgName} : ${symbolarg.argType}" - } - }) - argDef += "name : String = null" - argDef += "attr : Map[String, String] = null" // Construct Implementation field var impl = ListBuffer[String]() impl += "val map = scala.collection.mutable.Map[String, Any]()" @@ -125,12 +110,20 @@ private[mxnet] object SymbolImplMacros { case "type" => "typeOf" case default => symbolarg.argName } + if (symbolarg.isOptional) { + argDef += s"${currArgName} : Option[${symbolarg.argType}] = None" + } + else { + argDef += s"${currArgName} : ${symbolarg.argType}" + } var base = "map(\"" + symbolarg.argName + "\") = " + currArgName if (symbolarg.isOptional) { base = "if (!" + currArgName + ".isEmpty)" + base + ".get" } impl += base }) + argDef += "name : String = null" + argDef += "attr : Map[String, String] = null" // scalastyle:off impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, Seq(), map.toMap)" // scalastyle:on diff --git a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala index 86e92f15f605..bc8be7df5fb1 100644 --- a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala +++ b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala @@ -25,13 +25,6 @@ class MacrosSuite extends FunSuite with BeforeAndAfterAll { private val logger = LoggerFactory.getLogger(classOf[MacrosSuite]) - override def beforeAll() { - } - - override def afterAll(): Unit = { - - } - test("MacrosSuite-testArgumentCleaner") { val input = List( "Symbol, optional, default = Null", From ba926fcb6f66f19810f4c3d9ddfe660728959e14 Mon Sep 17 00:00:00 2001 From: Qing Date: Fri, 11 May 2018 17:15:43 -0700 Subject: [PATCH 24/26] add TODOs and name changes --- .../init/src/main/scala/org/apache/mxnet/init/Base.scala | 4 ++-- scala-package/macros/pom.xml | 2 +- .../macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala | 3 +++ 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala b/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala index 80b4bd0daaa8..94ccb47327a7 100644 --- a/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala +++ b/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala @@ -38,8 +38,8 @@ object Base { @throws(classOf[UnsatisfiedLinkError]) private def tryLoadInitLibrary(): Unit = { var baseDir = System.getProperty("user.dir") + "/init-native" - if (System.getenv().containsKey("MXNET_SCALA_MACRO_BASEDIR")) { - baseDir = sys.env("MXNET_SCALA_MACRO_BASEDIR") + if (System.getenv().containsKey("MXNET_BASEDIR")) { + baseDir = sys.env("MXNET_BASEDIR") } val os = System.getProperty("os.name") // ref: http://lopica.sourceforge.net/os.html diff --git a/scala-package/macros/pom.xml b/scala-package/macros/pom.xml index 8eea759c18a3..59cc181bd360 100644 --- a/scala-package/macros/pom.xml +++ b/scala-package/macros/pom.xml @@ -75,7 +75,7 @@ scalatest-maven-plugin - ${project.parent.basedir}/init-native + ${project.parent.basedir}/init-native -Djava.library.path=${project.parent.basedir}/native/${platform}/target \ diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index cc8d5444565b..234a8604cb91 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -89,6 +89,8 @@ private[mxnet] object SymbolImplMacros { case q"new AddSymbolAPIs($b)" => c.eval[Boolean](c.Expr(b)) } + // TODO: Put Symbol.api.foo --> Stable APIs + // Symbol.contrib.bar--> Contrib APIs val newSymbolFunctions = { if (isContrib) symbolFunctions.filter( func => func.name.startsWith("_contrib_") || !func.name.startsWith("_")) @@ -125,6 +127,7 @@ private[mxnet] object SymbolImplMacros { argDef += "name : String = null" argDef += "attr : Map[String, String] = null" // scalastyle:off + // TODO: Seq() here allows user to place Symbols rather than normal arguments to run, need to fix if old API deprecated impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, Seq(), map.toMap)" // scalastyle:on // Combine and build the function string From 0e8a17470e164a45b77e884220ba9b9840d87a24 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Fri, 11 May 2018 18:08:21 -0700 Subject: [PATCH 25/26] Update Base.scala Add relative path to MXNET_BASEDIR --- .../init/src/main/scala/org/apache/mxnet/init/Base.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala b/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala index 94ccb47327a7..825805ab6c6b 100644 --- a/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala +++ b/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala @@ -39,7 +39,7 @@ object Base { private def tryLoadInitLibrary(): Unit = { var baseDir = System.getProperty("user.dir") + "/init-native" if (System.getenv().containsKey("MXNET_BASEDIR")) { - baseDir = sys.env("MXNET_BASEDIR") + baseDir = sys.env("MXNET_BASEDIR") + "/scala-package/init-native" } val os = System.getProperty("os.name") // ref: http://lopica.sourceforge.net/os.html From c194066440c973b9533ef3a9b456f0af05473d53 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Fri, 11 May 2018 18:19:01 -0700 Subject: [PATCH 26/26] Update Base.scala --- .../init/src/main/scala/org/apache/mxnet/init/Base.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala b/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala index 825805ab6c6b..7402dbd3bc1d 100644 --- a/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala +++ b/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala @@ -38,8 +38,10 @@ object Base { @throws(classOf[UnsatisfiedLinkError]) private def tryLoadInitLibrary(): Unit = { var baseDir = System.getProperty("user.dir") + "/init-native" + // TODO(lanKing520) Update this to use relative path to the MXNet director. + // TODO(lanking520) baseDir = sys.env("MXNET_BASEDIR") + "/scala-package/init-native" if (System.getenv().containsKey("MXNET_BASEDIR")) { - baseDir = sys.env("MXNET_BASEDIR") + "/scala-package/init-native" + baseDir = sys.env("MXNET_BASEDIR") } val os = System.getProperty("os.name") // ref: http://lopica.sourceforge.net/os.html