From 40817c0862bdf371a750481f5733f891c0c83620 Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 9 May 2018 15:19:39 -0700 Subject: [PATCH 01/20] Add initial files --- .../org/apache/mxnet/APIDocGenerator.scala | 124 ++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala new file mode 100644 index 000000000000..20479377ef5b --- /dev/null +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +import org.apache.mxnet.init.Base._ + +import scala.collection.mutable.ListBuffer + +private[mxnet] object APIDocGenerator{ + case class traitArg(argName : String, argType : String, argDesc : String, isOptional : Boolean) + case class traitFunction(name : String, listOfArgs: List[traitArg]) + + + + // Convert C++ Types to Scala Types + def typeConversion(in : String, argType : String = "", returnType : String) : String = { + in match { + case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape" + case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType + case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]" + => s"Array[$returnType]" + case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat" + case "int" | "intorNone" | "int(non-negative)" => "Int" + case "long" | "long(non-negative)" => "Long" + case "double" | "doubleorNone" => "Double" + case "string" => "String" + case "boolean" => "Boolean" + case "tupleof" | "tupleof" | "ptr" | "" => "Any" + case default => throw new IllegalArgumentException( + s"Invalid type for args: $default, $argType") + } + } + + + /** + * By default, the argType come from the C++ API is a description more than a single word + * For Example: + * , , + * The three field shown above do not usually come at the same time + * This function used the above format to determine if the argument is + * optional, what is it Scala type and possibly pass in a default value + * @param argType Raw arguement Type description + * @return (Scala_Type, isOptional) + */ + def argumentCleaner(argType : String, returnType : String) : (String, Boolean) = { + val spaceRemoved = argType.replaceAll("\\s+", "") + var commaRemoved : Array[String] = new Array[String](0) + // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'} + if (spaceRemoved.charAt(0)== '{') { + val endIdx = spaceRemoved.indexOf('}') + commaRemoved = spaceRemoved.substring(endIdx + 1).split(",") + commaRemoved(0) = "string" + } else { + commaRemoved = spaceRemoved.split(",") + } + // Optional Field + if (commaRemoved.length >= 3) { + // arg: Type, optional, default = Null + require(commaRemoved(1).equals("optional")) + require(commaRemoved(2).startsWith("default=")) + (typeConversion(commaRemoved(0), argType, returnType), true) + } else if (commaRemoved.length == 2 || commaRemoved.length == 1) { + val tempType = typeConversion(commaRemoved(0), argType, returnType) + val tempOptional = tempType.equals("org.apache.mxnet.Symbol") + (tempType, tempOptional) + } else { + throw new IllegalArgumentException( + s"Unrecognized arg field: $argType, ${commaRemoved.length}") + } + + } + + + // List and add all the atomic symbol functions to current module. + private def initSymbolModule(isSymbol : Boolean): List[traitFunction] = { + val opNames = ListBuffer.empty[String] + val returnType = if (isSymbol) "Symbol" else "NDArray" + _LIB.mxListAllOpNames(opNames) + // TODO: Add '_linalg_', '_sparse_', '_image_' support + opNames.map(opName => { + val opHandle = new RefLong + _LIB.nnGetOpHandle(opName, opHandle) + makeAtomicSymbolFunction(opHandle.value, opName, "org.apache.mxnet." + returnType) + }).toList + } + + // Create an atomic symbol function by handle and function name. + private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String, returnType : String) + : traitFunction = { + val name = new RefString + val desc = new RefString + val keyVarNumArgs = new RefString + val numArgs = new RefInt + val argNames = ListBuffer.empty[String] + val argTypes = ListBuffer.empty[String] + val argDescs = ListBuffer.empty[String] + + _LIB.mxSymbolGetAtomicSymbolInfo( + handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs) + + val realName = if (aliasName == name.value) "" else s"(a.k.a., ${name.value})" + + val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) => + val typeAndOption = argumentCleaner(argType, returnType) + new traitArg(argName, typeAndOption._1, argDesc, typeAndOption._2) + } + new traitFunction(aliasName, argList.toList) + } +} From 4c7999c5a82478bc9f7c270fdf17e4e0d5982d3f Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 10 May 2018 15:18:15 -0700 Subject: [PATCH 02/20] finish major implementing the DocGen --- .../org/apache/mxnet/APIDocGenerator.scala | 61 ++++++++++++++++++- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index 20479377ef5b..f691a8b838c9 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -23,9 +23,66 @@ import scala.collection.mutable.ListBuffer private[mxnet] object APIDocGenerator{ case class traitArg(argName : String, argType : String, argDesc : String, isOptional : Boolean) - case class traitFunction(name : String, listOfArgs: List[traitArg]) + case class traitFunction(name : String, desc : String, + listOfArgs: List[traitArg], returnType : String) + + val FILE_PATH = "" + + def traitGen() : Unit = { + val traitFunctions = initSymbolModule(true) + val traitfuncs = traitFunctions.map(traitfunction => { + val scalaDoc = ScalaDocGen(traitfunction) + val traitBody = traitBodyGen(traitfunction) + s"$scalaDoc\n$traitBody" + }) + // scalastyle: off + val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n" + // scalastyle: on + val packageDef = "package org.apache.mxnet" + val traitDef = "trait SymbolAPIBase" + val finalStr = s"$apacheLicence\n$packageDef\n$traitDef {\n${traitfuncs.mkString("\n")}\n}" + import java.io._ + val pw = new PrintWriter(new File(FILE_PATH)) + pw.write(finalStr) + pw.close() + } + // Generate ScalaDoc type + def ScalaDocGen(traitFunc : traitFunction) : String = { + val desc = traitFunc.desc.split("\n").map({ currStr => + s" * $currStr" + }) + val params = traitFunc.listOfArgs.map({ traitarg => + val currArgName = traitarg.argName match { + case "var" => "vari" + case "type" => "typeOf" + case _ => traitarg.argName + } + s" * @param $currArgName\t\t${traitarg.argDesc}" + }) + val returnType = s" * @return ${traitFunc.returnType}" + s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n */" + } + def traitBodyGen(traitFunc : traitFunction) : String = { + var argDef = ListBuffer[String]() + traitFunc.listOfArgs.foreach(traitarg => { + val currArgName = traitarg.argName match { + case "var" => "vari" + case "type" => "typeOf" + case _ => traitarg.argName + } + if (traitarg.isOptional) { + argDef += s"$currArgName : Option[${traitarg.argType}] = None" + } + else { + argDef += s"$currArgName : ${traitarg.argType}" + } + }) + argDef += "name : String = null" + argDef += "attr : Map[String, String] = null" + s"def ${traitFunc.name} (${argDef.mkString(", ")}) : ${traitFunc.returnType}" + } // Convert C++ Types to Scala Types def typeConversion(in : String, argType : String = "", returnType : String) : String = { @@ -119,6 +176,6 @@ private[mxnet] object APIDocGenerator{ val typeAndOption = argumentCleaner(argType, returnType) new traitArg(argName, typeAndOption._1, argDesc, typeAndOption._2) } - new traitFunction(aliasName, argList.toList) + new traitFunction(aliasName, desc.value, argList.toList, returnType) } } From 216ff42069bea2d210cf30cb89d1b5c42df463ba Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 10 May 2018 17:04:09 -0700 Subject: [PATCH 03/20] Hide under score function --- .../src/main/scala/org/apache/mxnet/APIDocGenerator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index f691a8b838c9..7e7411b2bfc4 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -30,7 +30,7 @@ private[mxnet] object APIDocGenerator{ def traitGen() : Unit = { val traitFunctions = initSymbolModule(true) - val traitfuncs = traitFunctions.map(traitfunction => { + val traitfuncs = traitFunctions.filterNot(_.name.startsWith("_")).map(traitfunction => { val scalaDoc = ScalaDocGen(traitfunction) val traitBody = traitBodyGen(traitfunction) s"$scalaDoc\n$traitBody" From 3f9110427e94d18fbe9452c5cdb1acef36fbdcc7 Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 17 May 2018 13:28:36 -0700 Subject: [PATCH 04/20] add plugins that can execute the class during compile time --- scala-package/macros/pom.xml | 29 +++++++++++++++++++ .../org/apache/mxnet/APIDocGenerator.scala | 4 +++ 2 files changed, 33 insertions(+) diff --git a/scala-package/macros/pom.xml b/scala-package/macros/pom.xml index 59cc181bd360..6eae4cf7198b 100644 --- a/scala-package/macros/pom.xml +++ b/scala-package/macros/pom.xml @@ -70,6 +70,35 @@ org.apache.maven.plugins maven-compiler-plugin + + org.codehaus.mojo + exec-maven-plugin + 1.6.0 + + + my-execution + package + + java + + + + + + -Djava.library.path=${project.parent.basedir}/init-native/${platform}/target + -Dlog4j.configuration=file://${project.basedir}/src/test/resources/log4j.properties + + org.apache.mxnet.APIDocGenerator + + ${project.parent.basedir}/init-native + -Xmx8G + + + + + + + org.scalatest scalatest-maven-plugin diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index 7e7411b2bfc4..ee4d2d262d5d 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -28,6 +28,10 @@ private[mxnet] object APIDocGenerator{ val FILE_PATH = "" + def main(args: Array[String]) : Unit = { + traitGen() + } + def traitGen() : Unit = { val traitFunctions = initSymbolModule(true) val traitfuncs = traitFunctions.filterNot(_.name.startsWith("_")).map(traitfunction => { From 6a7e15b51ad5d8b7395994130bd1c3db9749dba6 Mon Sep 17 00:00:00 2001 From: Qing Date: Fri, 18 May 2018 12:41:46 -0700 Subject: [PATCH 05/20] Change the pom file and fix the running issues --- .../main/scala/org/apache/mxnet/SymbolAPI.scala | 2 +- scala-package/macros/pom.xml | 17 ++++++----------- .../org/apache/mxnet/APIDocGenerator.scala | 16 ++++++++-------- 3 files changed, 15 insertions(+), 20 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala index 49de9ae73218..56da4fa64cf4 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala @@ -22,5 +22,5 @@ package org.apache.mxnet * typesafe Symbol API: Symbol.api._ * Main code will be generated during compile time through Macros */ -object SymbolAPI { +object SymbolAPI extends SymbolAPIBase { } diff --git a/scala-package/macros/pom.xml b/scala-package/macros/pom.xml index 6eae4cf7198b..15b846a1cf0e 100644 --- a/scala-package/macros/pom.xml +++ b/scala-package/macros/pom.xml @@ -53,6 +53,7 @@ + @@ -76,7 +77,7 @@ 1.6.0 - my-execution + apidoc-generation package java @@ -84,19 +85,13 @@ + + ${project.parent.basedir}/init/target/classes + - -Djava.library.path=${project.parent.basedir}/init-native/${platform}/target - -Dlog4j.configuration=file://${project.basedir}/src/test/resources/log4j.properties + ${project.parent.basedir}/core/src/main/scala/org/apache/mxnet/ org.apache.mxnet.APIDocGenerator - - ${project.parent.basedir}/init-native - -Xmx8G - - - - - diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index ee4d2d262d5d..3755854bd8e3 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -26,27 +26,27 @@ private[mxnet] object APIDocGenerator{ case class traitFunction(name : String, desc : String, listOfArgs: List[traitArg], returnType : String) - val FILE_PATH = "" def main(args: Array[String]) : Unit = { - traitGen() + val FILE_PATH = args(0) + traitGen(FILE_PATH) } - def traitGen() : Unit = { + def traitGen(FILE_PATH : String) : Unit = { + // scalastyle:off val traitFunctions = initSymbolModule(true) val traitfuncs = traitFunctions.filterNot(_.name.startsWith("_")).map(traitfunction => { val scalaDoc = ScalaDocGen(traitfunction) val traitBody = traitBodyGen(traitfunction) s"$scalaDoc\n$traitBody" }) - // scalastyle: off val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n" - // scalastyle: on + val scalaStyle = "// scalastyle:off" val packageDef = "package org.apache.mxnet" - val traitDef = "trait SymbolAPIBase" - val finalStr = s"$apacheLicence\n$packageDef\n$traitDef {\n${traitfuncs.mkString("\n")}\n}" + val absClassDef = "abstract class SymbolAPIBase" + val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$absClassDef {\n${traitfuncs.mkString("\n")}\n}" import java.io._ - val pw = new PrintWriter(new File(FILE_PATH)) + val pw = new PrintWriter(new File(FILE_PATH + "SymbolAPIBase.scala")) pw.write(finalStr) pw.close() } From bcb439a2c466512981fefbda65d7d1ff6f439938 Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 21 May 2018 17:28:08 -0700 Subject: [PATCH 06/20] reTrigger CI From 28cf54892b5e086771c47b586a5ae808148d7202 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 22 May 2018 10:38:03 -0700 Subject: [PATCH 07/20] Add NDArray Support --- .../org/apache/mxnet/APIDocGenerator.scala | 47 ++++++++++--------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index 3755854bd8e3..dbe08a1516c2 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -29,24 +29,26 @@ private[mxnet] object APIDocGenerator{ def main(args: Array[String]) : Unit = { val FILE_PATH = args(0) - traitGen(FILE_PATH) + absClassGen(FILE_PATH, true) + absClassGen(FILE_PATH, false) } - def traitGen(FILE_PATH : String) : Unit = { + def absClassGen(FILE_PATH : String, isSymbol : Boolean) : Unit = { // scalastyle:off - val traitFunctions = initSymbolModule(true) + val traitFunctions = initSymbolModule(isSymbol) val traitfuncs = traitFunctions.filterNot(_.name.startsWith("_")).map(traitfunction => { val scalaDoc = ScalaDocGen(traitfunction) - val traitBody = traitBodyGen(traitfunction) + val traitBody = defBodyGen(traitfunction, isSymbol) s"$scalaDoc\n$traitBody" }) + val packageName = if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase" val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n" val scalaStyle = "// scalastyle:off" val packageDef = "package org.apache.mxnet" - val absClassDef = "abstract class SymbolAPIBase" + val absClassDef = s"abstract class $packageName" val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$absClassDef {\n${traitfuncs.mkString("\n")}\n}" import java.io._ - val pw = new PrintWriter(new File(FILE_PATH + "SymbolAPIBase.scala")) + val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala")) pw.write(finalStr) pw.close() } @@ -68,26 +70,29 @@ private[mxnet] object APIDocGenerator{ s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n */" } - def traitBodyGen(traitFunc : traitFunction) : String = { + def defBodyGen(traitFunc : traitFunction, isSymbol : Boolean) : String = { var argDef = ListBuffer[String]() traitFunc.listOfArgs.foreach(traitarg => { - val currArgName = traitarg.argName match { - case "var" => "vari" - case "type" => "typeOf" - case _ => traitarg.argName - } - if (traitarg.isOptional) { - argDef += s"$currArgName : Option[${traitarg.argType}] = None" - } - else { - argDef += s"$currArgName : ${traitarg.argType}" - } - }) - argDef += "name : String = null" - argDef += "attr : Map[String, String] = null" + val currArgName = traitarg.argName match { + case "var" => "vari" + case "type" => "typeOf" + case _ => traitarg.argName + } + if (traitarg.isOptional) { + argDef += s"$currArgName : Option[${traitarg.argType}] = None" + } + else { + argDef += s"$currArgName : ${traitarg.argType}" + } + }) + if (isSymbol) { + argDef += "name : String = null" + argDef += "attr : Map[String, String] = null" + } s"def ${traitFunc.name} (${argDef.mkString(", ")}) : ${traitFunc.returnType}" } + // Convert C++ Types to Scala Types def typeConversion(in : String, argType : String = "", returnType : String) : String = { in match { From e33f3fb824c3d4b4d1ca76c488bb2d29bc68a94d Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 22 May 2018 11:13:31 -0700 Subject: [PATCH 08/20] Change trait to absClass --- .../org/apache/mxnet/APIDocGenerator.scala | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index dbe08a1516c2..32a01089f8dc 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -22,9 +22,9 @@ import org.apache.mxnet.init.Base._ import scala.collection.mutable.ListBuffer private[mxnet] object APIDocGenerator{ - case class traitArg(argName : String, argType : String, argDesc : String, isOptional : Boolean) - case class traitFunction(name : String, desc : String, - listOfArgs: List[traitArg], returnType : String) + case class absClassArg(argName : String, argType : String, argDesc : String, isOptional : Boolean) + case class absClassFunction(name : String, desc : String, + listOfArgs: List[absClassArg], returnType : String) def main(args: Array[String]) : Unit = { @@ -35,18 +35,18 @@ private[mxnet] object APIDocGenerator{ def absClassGen(FILE_PATH : String, isSymbol : Boolean) : Unit = { // scalastyle:off - val traitFunctions = initSymbolModule(isSymbol) - val traitfuncs = traitFunctions.filterNot(_.name.startsWith("_")).map(traitfunction => { - val scalaDoc = ScalaDocGen(traitfunction) - val traitBody = defBodyGen(traitfunction, isSymbol) - s"$scalaDoc\n$traitBody" + val absClassFunctions = initSymbolModule(isSymbol) + val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_")).map(absClassFunction => { + val scalaDoc = ScalaDocGen(absClassFunction) + val defBody = defBodyGen(absClassFunction, isSymbol) + s"$scalaDoc\n$defBody" }) val packageName = if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase" val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n" val scalaStyle = "// scalastyle:off" val packageDef = "package org.apache.mxnet" val absClassDef = s"abstract class $packageName" - val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$absClassDef {\n${traitfuncs.mkString("\n")}\n}" + val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$absClassDef {\n${absFuncs.mkString("\n")}\n}" import java.io._ val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala")) pw.write(finalStr) @@ -54,42 +54,42 @@ private[mxnet] object APIDocGenerator{ } // Generate ScalaDoc type - def ScalaDocGen(traitFunc : traitFunction) : String = { + def ScalaDocGen(traitFunc : absClassFunction) : String = { val desc = traitFunc.desc.split("\n").map({ currStr => s" * $currStr" }) - val params = traitFunc.listOfArgs.map({ traitarg => - val currArgName = traitarg.argName match { + val params = traitFunc.listOfArgs.map({ absClassArg => + val currArgName = absClassArg.argName match { case "var" => "vari" case "type" => "typeOf" - case _ => traitarg.argName + case _ => absClassArg.argName } - s" * @param $currArgName\t\t${traitarg.argDesc}" + s" * @param $currArgName\t\t${absClassArg.argDesc}" }) val returnType = s" * @return ${traitFunc.returnType}" s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n */" } - def defBodyGen(traitFunc : traitFunction, isSymbol : Boolean) : String = { + def defBodyGen(absClassFunc : absClassFunction, isSymbol : Boolean) : String = { var argDef = ListBuffer[String]() - traitFunc.listOfArgs.foreach(traitarg => { - val currArgName = traitarg.argName match { + absClassFunc.listOfArgs.foreach(absClassArg => { + val currArgName = absClassArg.argName match { case "var" => "vari" case "type" => "typeOf" - case _ => traitarg.argName + case _ => absClassArg.argName } - if (traitarg.isOptional) { - argDef += s"$currArgName : Option[${traitarg.argType}] = None" + if (absClassArg.isOptional) { + argDef += s"$currArgName : Option[${absClassArg.argType}] = None" } else { - argDef += s"$currArgName : ${traitarg.argType}" + argDef += s"$currArgName : ${absClassArg.argType}" } }) if (isSymbol) { argDef += "name : String = null" argDef += "attr : Map[String, String] = null" } - s"def ${traitFunc.name} (${argDef.mkString(", ")}) : ${traitFunc.returnType}" + s"def ${absClassFunc.name} (${argDef.mkString(", ")}) : ${absClassFunc.returnType}" } @@ -153,7 +153,7 @@ private[mxnet] object APIDocGenerator{ // List and add all the atomic symbol functions to current module. - private def initSymbolModule(isSymbol : Boolean): List[traitFunction] = { + private def initSymbolModule(isSymbol : Boolean): List[absClassFunction] = { val opNames = ListBuffer.empty[String] val returnType = if (isSymbol) "Symbol" else "NDArray" _LIB.mxListAllOpNames(opNames) @@ -167,7 +167,7 @@ private[mxnet] object APIDocGenerator{ // Create an atomic symbol function by handle and function name. private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String, returnType : String) - : traitFunction = { + : absClassFunction = { val name = new RefString val desc = new RefString val keyVarNumArgs = new RefString @@ -183,8 +183,8 @@ private[mxnet] object APIDocGenerator{ val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) => val typeAndOption = argumentCleaner(argType, returnType) - new traitArg(argName, typeAndOption._1, argDesc, typeAndOption._2) + new absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2) } - new traitFunction(aliasName, desc.value, argList.toList, returnType) + new absClassFunction(aliasName, desc.value, argList.toList, returnType) } } From 5f8191d4323443bf38dd60d9635369b14ceb2beb Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 22 May 2018 11:30:25 -0700 Subject: [PATCH 09/20] Change names and add comments --- .../org/apache/mxnet/APIDocGenerator.scala | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index 32a01089f8dc..b3d5ae8ea1fe 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -21,6 +21,11 @@ import org.apache.mxnet.init.Base._ import scala.collection.mutable.ListBuffer +/** + * This object will generate the Scala documentation of the new Scala API + * Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala + * The code will be executed during Macros stage and file live in Core stage + */ private[mxnet] object APIDocGenerator{ case class absClassArg(argName : String, argType : String, argDesc : String, isOptional : Boolean) case class absClassFunction(name : String, desc : String, @@ -35,10 +40,10 @@ private[mxnet] object APIDocGenerator{ def absClassGen(FILE_PATH : String, isSymbol : Boolean) : Unit = { // scalastyle:off - val absClassFunctions = initSymbolModule(isSymbol) + val absClassFunctions = getSymbolNDArrayMethods(isSymbol) val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_")).map(absClassFunction => { - val scalaDoc = ScalaDocGen(absClassFunction) - val defBody = defBodyGen(absClassFunction, isSymbol) + val scalaDoc = generateAPIDocFromBackend(absClassFunction) + val defBody = generateAPISignature(absClassFunction, isSymbol) s"$scalaDoc\n$defBody" }) val packageName = if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase" @@ -54,7 +59,7 @@ private[mxnet] object APIDocGenerator{ } // Generate ScalaDoc type - def ScalaDocGen(traitFunc : absClassFunction) : String = { + def generateAPIDocFromBackend(traitFunc : absClassFunction) : String = { val desc = traitFunc.desc.split("\n").map({ currStr => s" * $currStr" }) @@ -70,7 +75,7 @@ private[mxnet] object APIDocGenerator{ s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n */" } - def defBodyGen(absClassFunc : absClassFunction, isSymbol : Boolean) : String = { + def generateAPISignature(absClassFunc : absClassFunction, isSymbol : Boolean) : String = { var argDef = ListBuffer[String]() absClassFunc.listOfArgs.foreach(absClassArg => { val currArgName = absClassArg.argName match { @@ -153,7 +158,7 @@ private[mxnet] object APIDocGenerator{ // List and add all the atomic symbol functions to current module. - private def initSymbolModule(isSymbol : Boolean): List[absClassFunction] = { + private def getSymbolNDArrayMethods(isSymbol : Boolean): List[absClassFunction] = { val opNames = ListBuffer.empty[String] val returnType = if (isSymbol) "Symbol" else "NDArray" _LIB.mxListAllOpNames(opNames) From 2fd91dadcb267daae0b7f59b20c50145c5f822a0 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 22 May 2018 11:46:31 -0700 Subject: [PATCH 10/20] Move the CtoScala Converter to an individual place --- .../org/apache/mxnet/APIDocGenerator.scala | 63 +-------------- .../scala/org/apache/mxnet/SymbolMacro.scala | 63 +-------------- .../apache/mxnet/utils/CToScalaUtils.scala | 80 +++++++++++++++++++ .../scala/org/apache/mxnet/MacrosSuite.scala | 3 +- 4 files changed, 87 insertions(+), 122 deletions(-) create mode 100644 scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index b3d5ae8ea1fe..1630d41917f0 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -18,6 +18,7 @@ package org.apache.mxnet import org.apache.mxnet.init.Base._ +import org.apache.mxnet.utils.CToScalaUtils import scala.collection.mutable.ListBuffer @@ -41,6 +42,7 @@ private[mxnet] object APIDocGenerator{ def absClassGen(FILE_PATH : String, isSymbol : Boolean) : Unit = { // scalastyle:off val absClassFunctions = getSymbolNDArrayMethods(isSymbol) + // TODO: Add Filter to the same location in case of refactor val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_")).map(absClassFunction => { val scalaDoc = generateAPIDocFromBackend(absClassFunction) val defBody = generateAPISignature(absClassFunction, isSymbol) @@ -98,65 +100,6 @@ private[mxnet] object APIDocGenerator{ } - // Convert C++ Types to Scala Types - def typeConversion(in : String, argType : String = "", returnType : String) : String = { - in match { - case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape" - case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType - case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]" - => s"Array[$returnType]" - case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat" - case "int" | "intorNone" | "int(non-negative)" => "Int" - case "long" | "long(non-negative)" => "Long" - case "double" | "doubleorNone" => "Double" - case "string" => "String" - case "boolean" => "Boolean" - case "tupleof" | "tupleof" | "ptr" | "" => "Any" - case default => throw new IllegalArgumentException( - s"Invalid type for args: $default, $argType") - } - } - - - /** - * By default, the argType come from the C++ API is a description more than a single word - * For Example: - * , , - * The three field shown above do not usually come at the same time - * This function used the above format to determine if the argument is - * optional, what is it Scala type and possibly pass in a default value - * @param argType Raw arguement Type description - * @return (Scala_Type, isOptional) - */ - def argumentCleaner(argType : String, returnType : String) : (String, Boolean) = { - val spaceRemoved = argType.replaceAll("\\s+", "") - var commaRemoved : Array[String] = new Array[String](0) - // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'} - if (spaceRemoved.charAt(0)== '{') { - val endIdx = spaceRemoved.indexOf('}') - commaRemoved = spaceRemoved.substring(endIdx + 1).split(",") - commaRemoved(0) = "string" - } else { - commaRemoved = spaceRemoved.split(",") - } - // Optional Field - if (commaRemoved.length >= 3) { - // arg: Type, optional, default = Null - require(commaRemoved(1).equals("optional")) - require(commaRemoved(2).startsWith("default=")) - (typeConversion(commaRemoved(0), argType, returnType), true) - } else if (commaRemoved.length == 2 || commaRemoved.length == 1) { - val tempType = typeConversion(commaRemoved(0), argType, returnType) - val tempOptional = tempType.equals("org.apache.mxnet.Symbol") - (tempType, tempOptional) - } else { - throw new IllegalArgumentException( - s"Unrecognized arg field: $argType, ${commaRemoved.length}") - } - - } - - // List and add all the atomic symbol functions to current module. private def getSymbolNDArrayMethods(isSymbol : Boolean): List[absClassFunction] = { val opNames = ListBuffer.empty[String] @@ -187,7 +130,7 @@ private[mxnet] object APIDocGenerator{ val realName = if (aliasName == name.value) "" else s"(a.k.a., ${name.value})" val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) => - val typeAndOption = argumentCleaner(argType, returnType) + val typeAndOption = CToScalaUtils.argumentCleaner(argType, returnType) new absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2) } new absClassFunction(aliasName, desc.value, argList.toList, returnType) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 234a8604cb91..bacbdb2e3075 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ListBuffer import scala.language.experimental.macros import scala.reflect.macros.blackbox import org.apache.mxnet.init.Base._ -import org.apache.mxnet.utils.OperatorBuildUtils +import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils} private[mxnet] class AddSymbolFunctions(isContrib: Boolean) extends StaticAnnotation { private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.addDefs @@ -178,65 +178,6 @@ private[mxnet] object SymbolImplMacros { result } - // Convert C++ Types to Scala Types - def typeConversion(in : String, argType : String = "") : String = { - in match { - case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape" - case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "org.apache.mxnet.Symbol" - case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]" - => "Array[org.apache.mxnet.Symbol]" - case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat" - case "int" | "intorNone" | "int(non-negative)" => "Int" - case "long" | "long(non-negative)" => "Long" - case "double" | "doubleorNone" => "Double" - case "string" => "String" - case "boolean" => "Boolean" - case "tupleof" | "tupleof" | "ptr" | "" => "Any" - case default => throw new IllegalArgumentException( - s"Invalid type for args: $default, $argType") - } - } - - - /** - * By default, the argType come from the C++ API is a description more than a single word - * For Example: - * , , - * The three field shown above do not usually come at the same time - * This function used the above format to determine if the argument is - * optional, what is it Scala type and possibly pass in a default value - * @param argType Raw arguement Type description - * @return (Scala_Type, isOptional) - */ - def argumentCleaner(argType : String) : (String, Boolean) = { - val spaceRemoved = argType.replaceAll("\\s+", "") - var commaRemoved : Array[String] = new Array[String](0) - // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'} - if (spaceRemoved.charAt(0)== '{') { - val endIdx = spaceRemoved.indexOf('}') - commaRemoved = spaceRemoved.substring(endIdx + 1).split(",") - commaRemoved(0) = "string" - } else { - commaRemoved = spaceRemoved.split(",") - } - // Optional Field - if (commaRemoved.length >= 3) { - // arg: Type, optional, default = Null - require(commaRemoved(1).equals("optional")) - require(commaRemoved(2).startsWith("default=")) - (typeConversion(commaRemoved(0), argType), true) - } else if (commaRemoved.length == 2 || commaRemoved.length == 1) { - val tempType = typeConversion(commaRemoved(0), argType) - val tempOptional = tempType.equals("org.apache.mxnet.Symbol") - (tempType, tempOptional) - } else { - throw new IllegalArgumentException( - s"Unrecognized arg field: $argType, ${commaRemoved.length}") - } - - } - - // List and add all the atomic symbol functions to current module. private def initSymbolModule(): List[SymbolFunction] = { val opNames = ListBuffer.empty[String] @@ -277,7 +218,7 @@ private[mxnet] object SymbolImplMacros { } // scalastyle:on println val argList = argNames zip argTypes map { case (argName, argType) => - val typeAndOption = argumentCleaner(argType) + val typeAndOption = CToScalaUtils.argumentCleaner(argType, "org.apache.mxnet.Symbol") new SymbolArg(argName, typeAndOption._1, typeAndOption._2) } new SymbolFunction(aliasName, argList.toList) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala new file mode 100644 index 000000000000..b7c4fb8560b6 --- /dev/null +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mxnet.utils + +private[mxnet] object CToScalaUtils { + + + + // Convert C++ Types to Scala Types + def typeConversion(in : String, argType : String = "", returnType : String) : String = { + in match { + case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape" + case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType + case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]" + => s"Array[$returnType]" + case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat" + case "int" | "intorNone" | "int(non-negative)" => "Int" + case "long" | "long(non-negative)" => "Long" + case "double" | "doubleorNone" => "Double" + case "string" => "String" + case "boolean" => "Boolean" + case "tupleof" | "tupleof" | "ptr" | "" => "Any" + case default => throw new IllegalArgumentException( + s"Invalid type for args: $default, $argType") + } + } + + + /** + * By default, the argType come from the C++ API is a description more than a single word + * For Example: + * , , + * The three field shown above do not usually come at the same time + * This function used the above format to determine if the argument is + * optional, what is it Scala type and possibly pass in a default value + * @param argType Raw arguement Type description + * @return (Scala_Type, isOptional) + */ + def argumentCleaner(argType : String, returnType : String) : (String, Boolean) = { + val spaceRemoved = argType.replaceAll("\\s+", "") + var commaRemoved : Array[String] = new Array[String](0) + // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'} + if (spaceRemoved.charAt(0)== '{') { + val endIdx = spaceRemoved.indexOf('}') + commaRemoved = spaceRemoved.substring(endIdx + 1).split(",") + commaRemoved(0) = "string" + } else { + commaRemoved = spaceRemoved.split(",") + } + // Optional Field + if (commaRemoved.length >= 3) { + // arg: Type, optional, default = Null + require(commaRemoved(1).equals("optional")) + require(commaRemoved(2).startsWith("default=")) + (typeConversion(commaRemoved(0), argType, returnType), true) + } else if (commaRemoved.length == 2 || commaRemoved.length == 1) { + val tempType = typeConversion(commaRemoved(0), argType, returnType) + val tempOptional = tempType.equals("org.apache.mxnet.Symbol") + (tempType, tempOptional) + } else { + throw new IllegalArgumentException( + s"Unrecognized arg field: $argType, ${commaRemoved.length}") + } + + } +} diff --git a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala index bc8be7df5fb1..5883a00c3315 100644 --- a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala +++ b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala @@ -17,6 +17,7 @@ package org.apache.mxnet +import org.apache.mxnet.utils.CToScalaUtils import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.slf4j.LoggerFactory @@ -42,7 +43,7 @@ class MacrosSuite extends FunSuite with BeforeAndAfterAll { ) for (idx <- input.indices) { - val result = SymbolImplMacros.argumentCleaner(input(idx)) + val result = CToScalaUtils.argumentCleaner(input(idx), "org.apache.mxnet.Symbol") assert(result._1 === output(idx)._1 && result._2 === output(idx)._2) } } From ae62897b6dab6f1e4b1e1e825a8ea9885570a02a Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 22 May 2018 13:46:08 -0700 Subject: [PATCH 11/20] Trigger CI From 7350091f9ae129ce23e04915f14f02c22f1868ce Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 22 May 2018 16:46:54 -0700 Subject: [PATCH 12/20] add image classification issues changes Fix issues that live inside of the image classification and object detection --- .../infer/imageclassifier/ImageClassifierExample.scala | 4 +++- .../infer/objectdetector/SSDClassifierExample.scala | 4 +++- .../main/scala/org/apache/mxnet/infer/ImageClassifier.scala | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala index 0ee0c119e439..bec5427da55c 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala @@ -99,7 +99,9 @@ object ImageClassifierExample { batch = ListBuffer[String]() } } - output += batch.toList + if (batch.length > 0) { + output += batch.toList + } output.toList } diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala index f4f7f5897893..f46e682fbf9a 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala @@ -100,7 +100,9 @@ object SSDClassifierExample { batch = ListBuffer[String]() } } - output += batch.toList + if (batch.length > 0) { + output += batch.toList + } output.toList } diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala index 6fa313b3e7f2..8d31d1f6b3d6 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala @@ -198,7 +198,7 @@ object ImageClassifier { /** * Loads a batch of images from a folder - * @param inputImageDirPath Path to a folder of images + * @param inputImagePaths Path to a folder of images * @return List of buffered images */ def loadInputBatch(inputImagePaths: List[String]): Traversable[BufferedImage] = { From 04cbb491d1db757cc4f8dd76e620801636ad1bb1 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 22 May 2018 17:16:47 -0700 Subject: [PATCH 13/20] in Sync with another PR --- .../src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala index b7c4fb8560b6..9d51ddcb674a 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala @@ -32,7 +32,7 @@ private[mxnet] object CToScalaUtils { case "long" | "long(non-negative)" => "Long" case "double" | "doubleorNone" => "Double" case "string" => "String" - case "boolean" => "Boolean" + case "boolean" | "booleanorNone" => "Boolean" case "tupleof" | "tupleof" | "ptr" | "" => "Any" case default => throw new IllegalArgumentException( s"Invalid type for args: $default, $argType") From aec9f91bd75b3039c0362946914cf8d88d2c8ebc Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 22 May 2018 20:59:36 -0700 Subject: [PATCH 14/20] trigger CI From 2e0bd1265db4d0737e488449ee407b31abac682d Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 23 May 2018 12:23:37 -0700 Subject: [PATCH 15/20] remove redundant lines and add NDArray Doc support --- .../scala/org/apache/mxnet/NDArrayAPI.scala | 2 +- .../scala/org/apache/mxnet/NDArrayMacro.scala | 63 +------------------ 2 files changed, 4 insertions(+), 61 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala index d234ac66bdd8..6136db29d1eb 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala @@ -20,5 +20,5 @@ package org.apache.mxnet * typesafe NDArray API: NDArray.api._ * Main code will be generated during compile time through Macros */ -object NDArrayAPI { +object NDArrayAPI extends NDArrayAPIBase { } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index bbe786f5a0af..fe348df033ca 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -18,7 +18,7 @@ package org.apache.mxnet import org.apache.mxnet.init.Base._ -import org.apache.mxnet.utils.OperatorBuildUtils +import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils} import scala.annotation.StaticAnnotation import scala.collection.mutable.ListBuffer @@ -129,7 +129,7 @@ private[mxnet] object NDArrayMacro { // scalastyle:on // Combine and build the function string val returnType = "org.apache.mxnet.NDArray" - var finalStr = s"def ${ndarrayfunction.name}New" + var finalStr = s"def ${ndarrayfunction.name}" finalStr += s" (${argDef.mkString(",")}) : $returnType" finalStr += s" = {${impl.mkString("\n")}}" c.parse(finalStr).asInstanceOf[DefDef] @@ -170,63 +170,6 @@ private[mxnet] object NDArrayMacro { } - // Convert C++ Types to Scala Types - private def typeConversion(in : String, argType : String = "") : String = { - in match { - case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape" - case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "org.apache.mxnet.NDArray" - case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]" - => "Array[org.apache.mxnet.NDArray]" - case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat" - case "int" | "intorNone" | "int(non-negative)" => "Int" - case "long" | "long(non-negative)" => "Long" - case "double" | "doubleorNone" => "Double" - case "string" => "String" - case "boolean" | "booleanorNone" => "Boolean" - case "tupleof" | "tupleof" | "ptr" | "" => "Any" - case default => throw new IllegalArgumentException( - s"Invalid type for args: $default, $argType") - } - } - - - /** - * By default, the argType come from the C++ API is a description more than a single word - * For Example: - * , , - * The three field shown above do not usually come at the same time - * This function used the above format to determine if the argument is - * optional, what is it Scala type and possibly pass in a default value - * @param argType Raw arguement Type description - * @return (Scala_Type, isOptional) - */ - private def argumentCleaner(argType : String) : (String, Boolean) = { - val spaceRemoved = argType.replaceAll("\\s+", "") - var commaRemoved : Array[String] = new Array[String](0) - // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'} - if (spaceRemoved.charAt(0)== '{') { - val endIdx = spaceRemoved.indexOf('}') - commaRemoved = spaceRemoved.substring(endIdx + 1).split(",") - commaRemoved(0) = "string" - } else { - commaRemoved = spaceRemoved.split(",") - } - // Optional Field - if (commaRemoved.length >= 3) { - // arg: Type, optional, default = Null - require(commaRemoved(1).equals("optional")) - require(commaRemoved(2).startsWith("default=")) - (typeConversion(commaRemoved(0), argType), true) - } else if (commaRemoved.length == 2 || commaRemoved.length == 1) { - val tempType = typeConversion(commaRemoved(0), argType) - val tempOptional = tempType.equals("org.apache.mxnet.NDArray") - (tempType, tempOptional) - } else { - throw new IllegalArgumentException( - s"Unrecognized arg field: $argType, ${commaRemoved.length}") - } - - } // List and add all the atomic symbol functions to current module. @@ -268,7 +211,7 @@ private[mxnet] object NDArrayMacro { } // scalastyle:on println val argList = argNames zip argTypes map { case (argName, argType) => - val typeAndOption = argumentCleaner(argType) + val typeAndOption = CToScalaUtils.argumentCleaner(argType, "org.apache.mxnet.NDArray") new NDArrayArg(argName, typeAndOption._1, typeAndOption._2) } new NDArrayFunction(aliasName, argList.toList) From bcbac488ea026996d2eced92b272f865b82e6fe7 Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 30 May 2018 14:56:15 -0700 Subject: [PATCH 16/20] Fix the return type of the NDArray, from NDArray to NDArrayFuncReturn --- .../src/main/scala/org/apache/mxnet/APIDocGenerator.scala | 5 ++++- .../src/main/scala/org/apache/mxnet/NDArrayMacro.scala | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index 1630d41917f0..151423611154 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -92,11 +92,14 @@ private[mxnet] object APIDocGenerator{ argDef += s"$currArgName : ${absClassArg.argType}" } }) + var returnType = absClassFunc.returnType if (isSymbol) { argDef += "name : String = null" argDef += "attr : Map[String, String] = null" + } else { + returnType = "org.apache.mxnet.NDArrayFuncReturn" } - s"def ${absClassFunc.name} (${argDef.mkString(", ")}) : ${absClassFunc.returnType}" + s"def ${absClassFunc.name} (${argDef.mkString(", ")}) : ${returnType}" } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index c8fd9759b70d..ce5b532bc8b8 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -133,7 +133,7 @@ private[mxnet] object NDArrayMacro { impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", null, map.toMap)" // scalastyle:on // Combine and build the function string - val returnType = "org.apache.mxnet.NDArray" + val returnType = "org.apache.mxnet.NDArrayFuncReturn" var finalStr = s"def ${ndarrayfunction.name}" finalStr += s" (${argDef.mkString(",")}) : $returnType" finalStr += s" = {${impl.mkString("\n")}}" From 2086e4066144ec00d6ea82a3521307eb1dd10f64 Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 30 May 2018 16:40:18 -0700 Subject: [PATCH 17/20] reTrigger CI From a32aa25925c7efbd596c461927ea3b7e89c18f82 Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 30 May 2018 20:19:15 -0700 Subject: [PATCH 18/20] reTrigger CI From 0c2840ba1cf546d184ff2572003c1940fb836482 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Thu, 31 May 2018 03:16:40 -0700 Subject: [PATCH 19/20] Update APIDocGenerator.scala --- .../scala/org/apache/mxnet/APIDocGenerator.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index 151423611154..d0aaba41baaa 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -61,11 +61,11 @@ private[mxnet] object APIDocGenerator{ } // Generate ScalaDoc type - def generateAPIDocFromBackend(traitFunc : absClassFunction) : String = { - val desc = traitFunc.desc.split("\n").map({ currStr => + def generateAPIDocFromBackend(func : absClassFunction) : String = { + val desc = func.desc.split("\n").map({ currStr => s" * $currStr" }) - val params = traitFunc.listOfArgs.map({ absClassArg => + val params = func.listOfArgs.map({ absClassArg => val currArgName = absClassArg.argName match { case "var" => "vari" case "type" => "typeOf" @@ -77,9 +77,9 @@ private[mxnet] object APIDocGenerator{ s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n */" } - def generateAPISignature(absClassFunc : absClassFunction, isSymbol : Boolean) : String = { + def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = { var argDef = ListBuffer[String]() - absClassFunc.listOfArgs.foreach(absClassArg => { + func.listOfArgs.foreach(absClassArg => { val currArgName = absClassArg.argName match { case "var" => "vari" case "type" => "typeOf" @@ -92,7 +92,7 @@ private[mxnet] object APIDocGenerator{ argDef += s"$currArgName : ${absClassArg.argType}" } }) - var returnType = absClassFunc.returnType + var returnType = func.returnType if (isSymbol) { argDef += "name : String = null" argDef += "attr : Map[String, String] = null" From 27ce453e9924cb905bf0e46055c0b6433f9da9d7 Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 31 May 2018 09:42:52 -0700 Subject: [PATCH 20/20] add name changes fix... --- .../src/main/scala/org/apache/mxnet/APIDocGenerator.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index d0aaba41baaa..90fe2604e8b6 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -73,7 +73,7 @@ private[mxnet] object APIDocGenerator{ } s" * @param $currArgName\t\t${absClassArg.argDesc}" }) - val returnType = s" * @return ${traitFunc.returnType}" + val returnType = s" * @return ${func.returnType}" s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n */" } @@ -99,7 +99,7 @@ private[mxnet] object APIDocGenerator{ } else { returnType = "org.apache.mxnet.NDArrayFuncReturn" } - s"def ${absClassFunc.name} (${argDef.mkString(", ")}) : ${returnType}" + s"def ${func.name} (${argDef.mkString(", ")}) : ${returnType}" }