From 507fb63bf580e32ea3ca35508a849d4105f3fe94 Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 12 Sep 2018 10:08:01 -0700 Subject: [PATCH 1/9] add initial commit for java NDArray --- .../main/scala/org/apache/mxnet/NDArray.scala | 49 +++++++++++++ .../main/scala/org/apache/mxnet/Shape.scala | 4 ++ .../apache/mxnet/api/java/ArgBuilder.scala | 69 +++++++++++++++++++ 3 files changed, 122 insertions(+) create mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/api/java/ArgBuilder.scala diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 9b6a7dc66540..102233f8eb14 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -567,6 +567,55 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, NDArrayCollector.collect(this) } + /** + * Java Flavor creating new NDArray + * @param arr + * @param shape + * @param ctx + * @return + */ + def this(arr : Array[Float], shape : Shape, ctx : Context) = { + this(NDArray.newAllocHandle(shape, ctx, delayAlloc = false, Base.MX_REAL_TYPE)) + this.set(arr) + } + + override def toString: String = { + val arr = this.toArray + val shape = this.shape.toArray + buildStringHelper(0, arr.length - 1, arr, shape, 0) + "\n" + } + + /** + * Helper function to create formatted NDArray output + * @param start starting position + * @param end ending position + * @param arr Float array of NDArray + * @param shape shape of NDArray + * @param dim current Dimension level of NDArray + * @return String format of NDArray + */ + private def buildStringHelper(start : Int, end : Int, arr : Array[Float], + shape : Array[Int], dim : Int) : String = { + var result = "" + if (dim != shape.length - 1) { + val length = shape(dim) + val fragment = (end - start + 1) / length + for (num <- 0 to length - 1) { + val output = buildStringHelper(start + fragment * num, start + fragment * (num + 1) - 1, + arr, shape, dim + 1) + result += s"$output\n" + } + result = s"${" " * dim}[\n$result${" " * dim}]" + } else { + var temp = ArrayBuffer[String]() + for (i <- start to end) { + temp += arr(i).toString + } + result = s"${" " * dim}[${temp.mkString(",")}]" + } + result + } + // record arrays who construct this array instance // we use weak reference to prevent gc blocking private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]] diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala index 689176217722..c4c1c3167b03 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala @@ -29,6 +29,10 @@ class Shape(dims: Traversable[Int]) extends Serializable { this(dims.toVector) } + def this(dims: Array[Int]) = { + this(dims.toVector) + } + def apply(dim: Int): Int = shape(dim) def get(dim: Int): Int = apply(dim) def size: Int = shape.size diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/api/java/ArgBuilder.scala b/scala-package/core/src/main/scala/org/apache/mxnet/api/java/ArgBuilder.scala new file mode 100644 index 000000000000..4b478e5fc5dd --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/api/java/ArgBuilder.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.api.java + +import scala.collection.mutable +import scala.collection.mutable.ListBuffer +import collection.JavaConverters._ + +/** + * This arg Builder is intent to solve Java to Scala conversion + * to take the input such as (arg: Any*) + */ +class ArgBuilder { + private var data = ListBuffer[Any]() + private var map = mutable.Map[String, Any]() + + def addArg(anyRef: AnyRef): ArgBuilder = { + require(map.isEmpty, + "Map is not empty, please do either key-value or positional-arg but not both") + this.data += anyRef.asInstanceOf[Any] + this + } + + def addArg(key : String, value : AnyRef) : ArgBuilder = { + require(data.isEmpty, + "Data is not empty, please do either key-value or positional-arg but not both") + this.map(key) = value.asInstanceOf[Any] + this + } + + def addBatchArgs(list : java.util.List[AnyRef]) : ArgBuilder = { + require(map.isEmpty, + "Map is not empty, please do either key-value or positional-arg but not both") + for (i <- 0 to list.size()) { + this.data += list.get(i) + } + this + } + + def addBatchArgs(arr : Array[AnyRef]) : ArgBuilder = { + require(map.isEmpty, + "Map is not empty, please do either key-value or positional-arg but not both") + arr.foreach(ele => this.data += ele) + this + } + + def buildMap() : Map[String, Any] = { + this.map.toMap + } + + def buildSeq() : Seq[Any] = { + this.data + } +} From 61ed5f648738295f7134f4e307f671513c4943fb Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 12 Sep 2018 17:07:27 -0700 Subject: [PATCH 2/9] add operator named functions --- .../main/scala/org/apache/mxnet/NDArray.scala | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 102233f8eb14..dcb01f6a46e0 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -989,6 +989,44 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, this } + /* Java Compatibility Functions + Function name with underscore means + it is going to do the operator as well as + update itself such as += + */ + def add(other : NDArray) : NDArray = this + other + def add(other : Float) : NDArray = this + other + def _add(other : NDArray) : NDArray = this += other + def _add(other : Float) : NDArray = this += other + def subtract(other : NDArray) : NDArray = this - other + def subtract(other : Float) : NDArray = this - other + def _subtract(other : NDArray) : NDArray = this -= other + def _subtract(other : Float) : NDArray = this -= other + def multiply(other : NDArray) : NDArray = this * other + def multiply(other : Float) : NDArray = this * other + def _multiply(other : NDArray) : NDArray = this *= other + def _multiply(other : Float) : NDArray = this *= other + def div(other : NDArray) : NDArray = this / other + def div(other : Float) : NDArray = this / other + def _div(other : NDArray) : NDArray = this /= other + def _div(other : Float) : NDArray = this /= other + def pow(other : NDArray) : NDArray = this ** other + def pow(other : Float) : NDArray = this ** other + def _pow(other : NDArray) : NDArray = this **= other + def _pow(other : Float) : NDArray = this **= other + def mod(other : NDArray) : NDArray = this % other + def mod(other : Float) : NDArray = this % other + def _mod(other : NDArray) : NDArray = this %= other + def _mod(other : Float) : NDArray = this %= other + def greater(other : NDArray) : NDArray = this > other + def greater(other : Float) : NDArray = this > other + def greaterEqual(other : NDArray) : NDArray = this >= other + def greaterEqual(other : Float) : NDArray = this >= other + def lesser(other : NDArray) : NDArray = this < other + def lesser(other : Float) : NDArray = this < other + def lesserEqual(other : NDArray) : NDArray = this <= other + def lesserEqual(other : Float) : NDArray = this <= other + /** * Return a copied flat java array of current array (row-major). * @return A copy of array content. From 824b1cdcaa6320bf011dc2ca6b0f0221f3bda07e Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 13 Sep 2018 15:07:22 -0700 Subject: [PATCH 3/9] add support for using new Type-safe api --- .../api/java/Java2ScalaConversionKit.scala | 34 +++++++++++++++++++ .../scala/org/apache/mxnet/NDArrayMacro.scala | 6 ++-- .../apache/mxnet/utils/CToScalaUtils.scala | 10 +++--- 3 files changed, 43 insertions(+), 7 deletions(-) create mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/api/java/Java2ScalaConversionKit.scala diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/api/java/Java2ScalaConversionKit.scala b/scala-package/core/src/main/scala/org/apache/mxnet/api/java/Java2ScalaConversionKit.scala new file mode 100644 index 000000000000..9b2905b4b572 --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/api/java/Java2ScalaConversionKit.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mxnet.api.java + +/** + * This kit is intend to bring Java user seamless experience + * Using Java to do what Scala have + */ +object Java2ScalaConversionKit { + + /** + * Wrap a object with Option + * element => Option[element] + * @param element the input object + * @tparam A Type param + * @return Option[A] + */ + def Option[A](element : A) : Option[A] = Some(element) + +} diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index 2d3a1c7ec5af..d87308f1ea7d 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -133,13 +133,15 @@ private[mxnet] object NDArrayMacro { "map(\"" + ndarrayarg.argName + "\") = " + currArgName } impl.append( - if (ndarrayarg.isOptional) s"if (!$currArgName.isEmpty) $base.get" + if (ndarrayarg.isOptional) { + s"if ($currArgName != null && !$currArgName.isEmpty) $base.get" + } else base ) }) // add default out parameter argDef += "out : Option[NDArray] = None" - impl += "if (!out.isEmpty) map(\"out\") = out.get" + impl += "if (out != null && !out.isEmpty) map(\"out\") = out.get" // scalastyle:off impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", args.toSeq, map.toMap)" // scalastyle:on diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala index d0ebe5b1d2cb..19a54193cec0 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala @@ -28,12 +28,12 @@ private[mxnet] object CToScalaUtils { case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]" => s"Array[$returnType]" - case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat" - case "int" | "intorNone" | "int(non-negative)" => "Int" - case "long" | "long(non-negative)" => "Long" - case "double" | "doubleorNone" => "Double" + case "float" | "real_t" | "floatorNone" => "java.lang.Float" + case "int" | "intorNone" | "int(non-negative)" => "java.lang.Integer" + case "long" | "long(non-negative)" => "java.lang.Long" + case "double" | "doubleorNone" => "java.lang.Double" case "string" => "String" - case "boolean" | "booleanorNone" => "Boolean" + case "boolean" | "booleanorNone" => "java.lang.Boolean" case "tupleof" | "tupleof" | "tupleof<>" | "ptr" | "" => "Any" case default => throw new IllegalArgumentException( s"Invalid type for args: $default\nString argType: $argType\nargName: $argName") From 8aabde00413980b1d821c6d488c290430e6a05b7 Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 17 Sep 2018 09:59:33 -0700 Subject: [PATCH 4/9] fix Scala Macros failure --- .../macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala index c3a7c58c1afc..4404b0885d57 100644 --- a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala +++ b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala @@ -36,7 +36,7 @@ class MacrosSuite extends FunSuite with BeforeAndAfterAll { ) val output = List( ("org.apache.mxnet.Symbol", true), - ("Int", false), + ("java.lang.Integer", false), ("org.apache.mxnet.Shape", true), ("String", true), ("Any", false) From d74f8812ed2b746125317c1162988831f6b25901 Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 17 Sep 2018 15:11:12 -0700 Subject: [PATCH 5/9] add Java compatible initialize methods and optimized toString --- .../main/scala/org/apache/mxnet/NDArray.scala | 55 ++++++++++++------- .../main/scala/org/apache/mxnet/Shape.scala | 5 ++ 2 files changed, 39 insertions(+), 21 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index dcb01f6a46e0..639d7d446978 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -248,6 +248,11 @@ object NDArray extends NDArrayBase { def ones(ctx: Context, shape: Int *): NDArray = ones(Shape(shape: _*), ctx) + // Java compatible conversion methods + def empty(shape: Array[Int]): NDArray = empty(Shape(shape)) + def zeros(shape: Array[Int]): NDArray = zeros(Shape(shape)) + def ones(shape: Array[Int]) : NDArray = ones(Shape(shape)) + /** * Create a new NDArray filled with given value, with specified shape. * @param shape shape of the NDArray. @@ -580,38 +585,44 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, } override def toString: String = { - val arr = this.toArray - val shape = this.shape.toArray - buildStringHelper(0, arr.length - 1, arr, shape, 0) + "\n" + buildStringHelper(this, this.shape.length) + "\n" } /** * Helper function to create formatted NDArray output - * @param start starting position - * @param end ending position - * @param arr Float array of NDArray - * @param shape shape of NDArray - * @param dim current Dimension level of NDArray + * The NDArray will be represented in a reduced version if too large + * @param nd NDArray as the input + * @param totalSpace totalSpace of the lowest dimension * @return String format of NDArray */ - private def buildStringHelper(start : Int, end : Int, arr : Array[Float], - shape : Array[Int], dim : Int) : String = { + private def buildStringHelper(nd : NDArray, totalSpace : Int) : String = { var result = "" - if (dim != shape.length - 1) { - val length = shape(dim) - val fragment = (end - start + 1) / length - for (num <- 0 to length - 1) { - val output = buildStringHelper(start + fragment * num, start + fragment * (num + 1) - 1, - arr, shape, dim + 1) + val THRESHOLD = 100000 // longest NDArray to show in full + val ARRAYTHRESHOLD = 1000 // longest array to show in full + val shape = nd.shape + val space = totalSpace - shape.length + if (shape.length != 1) { + val (length, postfix) = + if (shape.product > THRESHOLD) { + // reduced NDArray + (1, s"\n${" " * (space + 1)}... with length ${shape(0)}\n") + } else { + (shape(0), "") + } + for (num <- 0 until length) { + val output = buildStringHelper(nd.at(num), totalSpace) result += s"$output\n" } - result = s"${" " * dim}[\n$result${" " * dim}]" + result = s"${" " * space}[\n$result${" " * space}$postfix]" } else { - var temp = ArrayBuffer[String]() - for (i <- start to end) { - temp += arr(i).toString + if (shape(0) > ARRAYTHRESHOLD) { + // reduced Array + val front = nd.slice(0, 10) + val back = nd.slice(shape(0) - 10, shape(0) - 1) + result = s"${" " * space}[${front.toArray.mkString(",")} ... ${back.toArray.mkString(",")}]" + } else { + result = s"${" " * space}[${nd.toArray.mkString(",")}]" } - result = s"${" " * dim}[${temp.mkString(",")}]" } result } @@ -1032,6 +1043,8 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, * @return A copy of array content. */ def toArray: Array[Float] = { + require(shape.toArray.product < 1000000, + "NDArray size is too large, consider reducing the dimension") internal.toFloatArray } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala index c4c1c3167b03..772ec8e262c1 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala @@ -29,6 +29,11 @@ class Shape(dims: Traversable[Int]) extends Serializable { this(dims.toVector) } + /** + * Java compatible constructor + * @param dims Array of Int input + * @return Shape + */ def this(dims: Array[Int]) = { this(dims.toVector) } From ca27685ef35dc2a9e5506ee7e622de409855354b Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 18 Sep 2018 15:33:02 -0700 Subject: [PATCH 6/9] remove the require field and rename the toString method --- .../core/src/main/scala/org/apache/mxnet/NDArray.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 639d7d446978..b31806d45b1d 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -584,7 +584,11 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, this.set(arr) } - override def toString: String = { + /** + * Visualize the internal structure of NDArray + * @return String that show the structure + */ + def visualize: String = { buildStringHelper(this, this.shape.length) + "\n" } @@ -1043,8 +1047,6 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, * @return A copy of array content. */ def toArray: Array[Float] = { - require(shape.toArray.product < 1000000, - "NDArray size is too large, consider reducing the dimension") internal.toFloatArray } From 4cb03d676cb7d7384b32fd02c350bb846cba9485 Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 19 Sep 2018 15:50:06 -0700 Subject: [PATCH 7/9] revert changes on Macros and will apply in different PR --- .../api/java/Java2ScalaConversionKit.scala | 34 ------------------- .../scala/org/apache/mxnet/NDArrayMacro.scala | 6 ++-- .../apache/mxnet/utils/CToScalaUtils.scala | 10 +++--- .../scala/org/apache/mxnet/MacrosSuite.scala | 2 +- 4 files changed, 8 insertions(+), 44 deletions(-) delete mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/api/java/Java2ScalaConversionKit.scala diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/api/java/Java2ScalaConversionKit.scala b/scala-package/core/src/main/scala/org/apache/mxnet/api/java/Java2ScalaConversionKit.scala deleted file mode 100644 index 9b2905b4b572..000000000000 --- a/scala-package/core/src/main/scala/org/apache/mxnet/api/java/Java2ScalaConversionKit.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.mxnet.api.java - -/** - * This kit is intend to bring Java user seamless experience - * Using Java to do what Scala have - */ -object Java2ScalaConversionKit { - - /** - * Wrap a object with Option - * element => Option[element] - * @param element the input object - * @tparam A Type param - * @return Option[A] - */ - def Option[A](element : A) : Option[A] = Some(element) - -} diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index d87308f1ea7d..2d3a1c7ec5af 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -133,15 +133,13 @@ private[mxnet] object NDArrayMacro { "map(\"" + ndarrayarg.argName + "\") = " + currArgName } impl.append( - if (ndarrayarg.isOptional) { - s"if ($currArgName != null && !$currArgName.isEmpty) $base.get" - } + if (ndarrayarg.isOptional) s"if (!$currArgName.isEmpty) $base.get" else base ) }) // add default out parameter argDef += "out : Option[NDArray] = None" - impl += "if (out != null && !out.isEmpty) map(\"out\") = out.get" + impl += "if (!out.isEmpty) map(\"out\") = out.get" // scalastyle:off impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", args.toSeq, map.toMap)" // scalastyle:on diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala index 19a54193cec0..d0ebe5b1d2cb 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala @@ -28,12 +28,12 @@ private[mxnet] object CToScalaUtils { case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]" => s"Array[$returnType]" - case "float" | "real_t" | "floatorNone" => "java.lang.Float" - case "int" | "intorNone" | "int(non-negative)" => "java.lang.Integer" - case "long" | "long(non-negative)" => "java.lang.Long" - case "double" | "doubleorNone" => "java.lang.Double" + case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat" + case "int" | "intorNone" | "int(non-negative)" => "Int" + case "long" | "long(non-negative)" => "Long" + case "double" | "doubleorNone" => "Double" case "string" => "String" - case "boolean" | "booleanorNone" => "java.lang.Boolean" + case "boolean" | "booleanorNone" => "Boolean" case "tupleof" | "tupleof" | "tupleof<>" | "ptr" | "" => "Any" case default => throw new IllegalArgumentException( s"Invalid type for args: $default\nString argType: $argType\nargName: $argName") diff --git a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala index 4404b0885d57..c3a7c58c1afc 100644 --- a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala +++ b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala @@ -36,7 +36,7 @@ class MacrosSuite extends FunSuite with BeforeAndAfterAll { ) val output = List( ("org.apache.mxnet.Symbol", true), - ("java.lang.Integer", false), + ("Int", false), ("org.apache.mxnet.Shape", true), ("String", true), ("Any", false) From 494b7b5e6446ed201cc468b2015f5d21319e6fa0 Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 20 Sep 2018 11:27:08 -0700 Subject: [PATCH 8/9] adding new NDArray test --- .../core/src/test/scala/org/apache/mxnet/NDArraySuite.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala index 5d88bb39e502..edbc98e5a11a 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala @@ -38,6 +38,11 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { assert(ndones.toScalar === 1f) } + test("new NDArray") { + val ndarray = new NDArray(Array(1.0f, 2.0f), Shape(1, 2), Context.cpu()) + assert(ndarray.shape == Shape(1, 2)) + } + test ("call toScalar on an ndarray which is not a scalar") { intercept[Exception] { NDArray.zeros(1, 1).toScalar } } From bf8792df76cc5672b63c37cea7f6a8ed47e62780 Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 27 Sep 2018 10:12:12 -0700 Subject: [PATCH 9/9] remove unrelated method to Java --- .../main/scala/org/apache/mxnet/NDArray.scala | 47 ------------------- 1 file changed, 47 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index b31806d45b1d..dc1273315b24 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -584,53 +584,6 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, this.set(arr) } - /** - * Visualize the internal structure of NDArray - * @return String that show the structure - */ - def visualize: String = { - buildStringHelper(this, this.shape.length) + "\n" - } - - /** - * Helper function to create formatted NDArray output - * The NDArray will be represented in a reduced version if too large - * @param nd NDArray as the input - * @param totalSpace totalSpace of the lowest dimension - * @return String format of NDArray - */ - private def buildStringHelper(nd : NDArray, totalSpace : Int) : String = { - var result = "" - val THRESHOLD = 100000 // longest NDArray to show in full - val ARRAYTHRESHOLD = 1000 // longest array to show in full - val shape = nd.shape - val space = totalSpace - shape.length - if (shape.length != 1) { - val (length, postfix) = - if (shape.product > THRESHOLD) { - // reduced NDArray - (1, s"\n${" " * (space + 1)}... with length ${shape(0)}\n") - } else { - (shape(0), "") - } - for (num <- 0 until length) { - val output = buildStringHelper(nd.at(num), totalSpace) - result += s"$output\n" - } - result = s"${" " * space}[\n$result${" " * space}$postfix]" - } else { - if (shape(0) > ARRAYTHRESHOLD) { - // reduced Array - val front = nd.slice(0, 10) - val back = nd.slice(shape(0) - 10, shape(0) - 1) - result = s"${" " * space}[${front.toArray.mkString(",")} ... ${back.toArray.mkString(",")}]" - } else { - result = s"${" " * space}[${nd.toArray.mkString(",")}]" - } - } - result - } - // record arrays who construct this array instance // we use weak reference to prevent gc blocking private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]]