[MXNET-357] New Scala API Design (Symbol)#10660
Conversation
| ) | ||
| ) | ||
| ) | ||
| val newFunctionDefs = newSymbolFunctions map { symbolfunction => |
There was a problem hiding this comment.
remove if it is not used currently.
There was a problem hiding this comment.
It will be implemented in the next commit
There was a problem hiding this comment.
I know, but remove this for now.
|
|
||
| private def argumentCleaner(argType : String) : (String, Boolean) = { | ||
| val spaceRemoved = argType.replaceAll("\\s+", "") | ||
| var commaRemoved : Array[String] = new Array[String](0) |
There was a problem hiding this comment.
looks like you can write
val commaRemoved = if ... else ...
There was a problem hiding this comment.
Seemed not applicable as I need to change one of the element in the Array
| // scalastyle:on println | ||
| (aliasName, new SymbolFunction(handle, keyVarNumArgs.value)) | ||
| val argList = argNames zip argTypes map { case (argName, argType) => | ||
| val tup = argumentCleaner(argType) |
| val endIdx = spaceRemoved.indexOf('}') | ||
| commaRemoved = spaceRemoved.substring(endIdx + 1).split(",") | ||
| // commaRemoved(0) = spaceRemoved.substring(0, endIdx+1) | ||
| commaRemoved(0) = "string" |
There was a problem hiding this comment.
could you explain more about the process logic here? I don't quite get the point.
There was a problem hiding this comment.
The input can be in the format:
e.g: stype : {'csr', 'default', 'row_sparse'}
In which case we need to get rid of the "{}" and set the type as string. This part is just a part of the data cleaning
| commaRemoved = spaceRemoved.split(",") | ||
| } | ||
| // Optional Field | ||
| if (commaRemoved.length >= 3) { |
There was a problem hiding this comment.
There are current different format for this, usually:
arg : Type
arg: Type, required
arg: Type, optional, default = Null
The logic here is trying to handle all of these cases
There was a problem hiding this comment.
We'd better make the pattern clear. For example, do assertion
if (commaRemoved.length >= 3) {
// arg: Type, optional, default = Null
require(commaRemoved[1] == "required")
require(commaRemoved[2].startsWith("default = ")so that if some other patterns appear one day, we can fail immediately and have chance to fix it before it leaks to public and causes strange error.
| val opNames = ListBuffer.empty[String] | ||
| _LIB.mxListAllOpNames(opNames) | ||
| opNames.map(opName => { | ||
| opNames.filter(!_.startsWith("_")).map(opName => { |
There was a problem hiding this comment.
The reason for filter _ is to remove the internal function to be compiled. The Documentation for Internal function has not been updated for a long time.
There was a problem hiding this comment.
Can we have a list of those internal functions? Since it is a broken for api compatibility, we need to review whether it is safe to remove.
btw, it removes _contrib_ as well, you don't mean that, right?
There was a problem hiding this comment.
Yeah, not mean't to remove contrib
|
The recent commit including the new API functions, currently we call them "New". To access them, build the package and call <Function_Name>New and you will be able to access. In order to merge this PR, Example is required to prove the API will functioned normally. At least one example will be added in this case |
| val funcName = symbolfunction.name | ||
| val tName = TermName(funcName) | ||
| q""" | ||
| @Deprecated |
There was a problem hiding this comment.
suggest to deprecate them later, when the new api is proved to be stable.
There was a problem hiding this comment.
Will remove it in the next commit
| impl += "createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, Seq(), map.toMap)" | ||
| // Combine and build the function string | ||
| val returnType = "org.apache.mxnet.Symbol" | ||
| var finalStr = s"def ${symbolfunction.name}New" |
There was a problem hiding this comment.
How about postfix 'Ex' for 'Expand', which is also consistent with those in c_api.h
There was a problem hiding this comment.
Instead of adding postfix, we decided to call the API as Symbol.api.Function name
| if (spaceRemoved.charAt(0)== '{') { | ||
| val endIdx = spaceRemoved.indexOf('}') | ||
| commaRemoved = spaceRemoved.substring(endIdx + 1).split(",") | ||
| commaRemoved(0) = "string" |
There was a problem hiding this comment.
if argType = {'csr', 'default', 'row_sparse'}, then we will do typeConversion("string", "{'csr', 'default', 'row_sparse'}") ?
if so, then these two lines are pretty confusing. why not simply commaRemoved = Array("string")
There was a problem hiding this comment.
Previously I was thinking adding these into the default field, but just found them unnecessary. The reason not doing Array("string") is we are not sure if this arg contains optional field. We need to do a split "," to make sure of that
There was a problem hiding this comment.
Adding more comments here to avoid misunderstanding
| commaRemoved = spaceRemoved.split(",") | ||
| } | ||
| // Optional Field | ||
| if (commaRemoved.length >= 3) { |
There was a problem hiding this comment.
We'd better make the pattern clear. For example, do assertion
if (commaRemoved.length >= 3) {
// arg: Type, optional, default = Null
require(commaRemoved[1] == "required")
require(commaRemoved[2].startsWith("default = ")so that if some other patterns appear one day, we can fail immediately and have chance to fix it before it leaks to public and causes strange error.
nswamy
left a comment
There was a problem hiding this comment.
Great Job! 💯 . This is using advanced features and scala macros are really cryptic, I want to encourage you to add lots of comments so someone else do not have to break their head like you had to.
| */ | ||
| package org.apache.mxnet | ||
| @AddNewSymbolFunctions(false) | ||
| object NewSymbol { |
| private val functions: Map[String, SymbolFunction] = initSymbolModule() | ||
| private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3) | ||
|
|
||
| val api = NewSymbol |
| private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.addDefs | ||
| } | ||
|
|
||
| private[mxnet] class AddNewSymbolFunctions(isContrib: Boolean) extends StaticAnnotation { |
There was a problem hiding this comment.
AddNewSymbolFunctions-> AddSymbolAPIs ?
| * limitations under the License. | ||
| */ | ||
| package org.apache.mxnet | ||
| @AddNewSymbolFunctions(false) |
There was a problem hiding this comment.
AddNewSymbolFunctions->GenerateSymbolAPIs ?
| } | ||
| // scalastyle:off havetype | ||
| def addNewDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { | ||
| impl(c)(false, true, annottees: _*) |
There was a problem hiding this comment.
Please create a new method for generating the new APIs
There was a problem hiding this comment.
I think we should keep using impl to avoid duplicated code, since the new API implementation is just a small component in
| argDef += "attr : Map[String, String] = null" | ||
| // Construct Implementation field | ||
| var impl = ListBuffer[String]() | ||
| impl += "val map = scala.collection.mutable.Map[String, Any]()" |
There was a problem hiding this comment.
Can you not do this in the above loop
| private def typeConversion(in : String, argType : String = "") : String = { | ||
| in match { | ||
| case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape" | ||
| case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "org.apache.mxnet.Symbol" |
There was a problem hiding this comment.
Should NDArray also return Symbol?
There was a problem hiding this comment.
Yes, they share the same Function name
| case "double" => "Double" | ||
| case "string" => "String" | ||
| case "boolean" => "Boolean" | ||
| case "tupleof<float>" => "Any" |
There was a problem hiding this comment.
Why Any? Any will remove type checks
| val opNames = ListBuffer.empty[String] | ||
| _LIB.mxListAllOpNames(opNames) | ||
| opNames.map(opName => { | ||
| opNames.filter(op => !op.startsWith("_") || op.startsWith("_contrib_")).map(opName => { |
There was a problem hiding this comment.
We are filtering all operators that start with _ unfortunately the sparse, linear algebra and other operators also start with _. Look at this https://github.com/apache/incubator-mxnet/blob/4fb5241b47c8147690fd6408b55cb694d544656e/python/mxnet/base.py#L455 We need to revisit this again
There was a problem hiding this comment.
Sure, let's add a TODO in here to make sure adding more functions.
|
|
||
|
|
||
|
|
||
| private def argumentCleaner(argType : String) : (String, Boolean) = { |
There was a problem hiding this comment.
Please add comments of what this method is trying to achieve. Also, please add how the structure of the API looks like when extracted from C++ and at the end when you are done cleaning up.
|
Now is turn out to be fun. I removed the underscore filter and adding support for underscore function generation. Please kindly review the code parser section and find if there are possible ways to convert some "Any"s to actual Scala types |
|
In the latest commit, unit test were added for testing Scala API. In order to help init-native/base find the correct path to import the library, a environment variable called BASEDIR were used to help user customize the base directory they need. Unit test specifically focus on the Argument cleaner to make sure it can correctly handle different argument type descriptions |
| var baseDir = System.getProperty("user.dir") + "/init-native" | ||
| if (System.getenv().containsKey("BASEDIR")) { | ||
| baseDir = sys.env("BASEDIR") | ||
| } |
There was a problem hiding this comment.
what are you expecting BASEDIR variable to be?
I think you should have a else
else { baseDir System.getProperty("user.dir") } baseDir = baseDir + "/init-native"
There was a problem hiding this comment.
BASEDIR locate in the pom file (set as an environment variable) which determine the current base directory.
There was a problem hiding this comment.
for system environment var, better to specify it is mxnet-related, e.g., MXNET_SCALA_MACRO_BASEDIR or something like it.
| } | ||
| // scalastyle:off havetype | ||
| def addNewDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { | ||
| newAPIImpl(c)(annottees: _*) |
There was a problem hiding this comment.
can we have a more meaningful name, for example, typeSafeAPI?
|
|
||
| // Construct argument field | ||
| var argDef = ListBuffer[String]() | ||
| symbolfunction.listOfArgs.foreach(symbolarg => { |
There was a problem hiding this comment.
can we combine this with the next foreach, i.e., line 120?
| if (spaceRemoved.charAt(0)== '{') { | ||
| val endIdx = spaceRemoved.indexOf('}') | ||
| commaRemoved = spaceRemoved.substring(endIdx + 1).split(",") | ||
| commaRemoved(0) = "string" |
| // Optional Field | ||
| if (commaRemoved.length >= 3) { | ||
| // arg: Type, optional, default = Null | ||
| require(commaRemoved(1).equals("optional")) |
There was a problem hiding this comment.
just remind, better to use == in scala. == behaves the same as equals in Java, and equals in Scala behaves the same as == in Java... sigh... Since String is immutable, equals here is fine.
There was a problem hiding this comment.
Understood, will note this.
|
|
||
| override def afterAll(): Unit = { | ||
|
|
||
| } |
| */ | ||
| package org.apache.mxnet | ||
| @AddSymbolAPIs(false) | ||
| object SymbolAPI { |
There was a problem hiding this comment.
comment a bit for this placeholder.
| @throws(classOf[UnsatisfiedLinkError]) | ||
| private def tryLoadInitLibrary(): Unit = { | ||
| val baseDir = System.getProperty("user.dir") + "/init-native" | ||
| // val baseDir = System.getProperty("user.dir") + "/init-native" |
| var baseDir = System.getProperty("user.dir") + "/init-native" | ||
| if (System.getenv().containsKey("BASEDIR")) { | ||
| baseDir = sys.env("BASEDIR") | ||
| } |
There was a problem hiding this comment.
for system environment var, better to specify it is mxnet-related, e.g., MXNET_SCALA_MACRO_BASEDIR or something like it.
nswamy
left a comment
There was a problem hiding this comment.
minor comments, please address them. we can merge it.
| private def tryLoadInitLibrary(): Unit = { | ||
| val baseDir = System.getProperty("user.dir") + "/init-native" | ||
| var baseDir = System.getProperty("user.dir") + "/init-native" | ||
| if (System.getenv().containsKey("MXNET_SCALA_MACRO_BASEDIR")) { |
There was a problem hiding this comment.
can we change this to MXNET_BASEDIR and we can append the macro dir location relative to it. I don't think the user needs to understand what macros are or find out macro dir.
| <artifactId>scalatest-maven-plugin</artifactId> | ||
| <configuration> | ||
| <environmentVariables> | ||
| <MXNET_SCALA_MACRO_BASEDIR>${project.parent.basedir}/init-native</MXNET_SCALA_MACRO_BASEDIR> |
| if (isContrib) symbolFunctions.filter(_._1.startsWith("_contrib_")) | ||
| else symbolFunctions.filter(!_._1.startsWith("_contrib_")) | ||
| if (isContrib) symbolFunctions.filter( | ||
| func => func.name.startsWith("_contrib_") || !func.name.startsWith("_")) |
There was a problem hiding this comment.
Thinking of how this flag would be used, one possible use-case is we pull all the contrib APIs into a separate Object to make it more explicit as in ?
Symbol.api.foo --> Stable APIs
Symbol.contrib.bar--> Contrib APIs
We can do this as a separate PR, thoughts?
| if (isContrib) funcName.substring(funcName.indexOf("_contrib_") + "_contrib_".length()) | ||
| else funcName | ||
|
|
||
| val functionDefs = newSymbolFunctions map { symbolfunction => |
There was a problem hiding this comment.
I like this code way better than the old, but I think we introducing new way of generating code for old and new APIs, if there is any bug that we haven't caught it will break both old and new APIs. I think its a risk.
|
|
||
| val newSymbolFunctions = { | ||
| if (isContrib) symbolFunctions.filter( | ||
| func => func.name.startsWith("_contrib_") || !func.name.startsWith("_")) |
There was a problem hiding this comment.
same as above, thoughts? can be done as a separate PR
| argDef += "name : String = null" | ||
| argDef += "attr : Map[String, String] = null" | ||
| // scalastyle:off | ||
| impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, Seq(), map.toMap)" |
There was a problem hiding this comment.
Can we make a note to Seq() when we deprecate old APIs
Add relative path to MXNET_BASEDIR
| val baseDir = System.getProperty("user.dir") + "/init-native" | ||
| var baseDir = System.getProperty("user.dir") + "/init-native" | ||
| if (System.getenv().containsKey("MXNET_BASEDIR")) { | ||
| baseDir = sys.env("MXNET_BASEDIR") + "/scala-package/init-native" |
There was a problem hiding this comment.
just FYI, I updated this line from
baseDir = sys.env("MXNET_BASEDIR")
to
baseDir = sys.env("MXNET_BASEDIR") + "/scala-package/init-native"
* Simplfied current Macros impl to Quasiquote * Change the Symbol Function Field, add SymbolArg * Fix the Macros problem, disable the hidden function _ * Add Implementation for New API * Add examples and comments * Add _contrib_ support * New namespace for Symbol API * Change names and add comments * add TODOs and name changes * Add relative path to MXNET_BASEDIR * Update Base.scala
* Simplfied current Macros impl to Quasiquote * Change the Symbol Function Field, add SymbolArg * Fix the Macros problem, disable the hidden function _ * Add Implementation for New API * Add examples and comments * Add _contrib_ support * New namespace for Symbol API * Change names and add comments * add TODOs and name changes * Add relative path to MXNET_BASEDIR * Update Base.scala
* Simplfied current Macros impl to Quasiquote * Change the Symbol Function Field, add SymbolArg * Fix the Macros problem, disable the hidden function _ * Add Implementation for New API * Add examples and comments * Add _contrib_ support * New namespace for Symbol API * Change names and add comments * add TODOs and name changes * Add relative path to MXNET_BASEDIR * Update Base.scala
Description
See full design document
@nswamy @yzhliu
This PR is the Addition for new Symbol functions of Scala API
Checklist
Essentials