From e33aae658249c444f82d6cc22aa83c2fc42fa9a0 Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 27 Sep 2018 10:11:22 -0700 Subject: [PATCH 01/11] add visualize --- .../main/scala/org/apache/mxnet/NDArray.scala | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 163ed2682532..38d6445e3d64 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -763,6 +763,52 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, checkCall(_LIB.mxFloat64NDArraySyncCopyFromCPU(handle, source, source.length)) } + /** + * Visualize the internal structure of NDArray + * @return String that show the structure + */ + def visualize: String = { + buildStringHelper(this, this.shape.length) + "\n" + } + /** + * Helper function to create formatted NDArray output + * The NDArray will be represented in a reduced version if too large + * @param nd NDArray as the input + * @param totalSpace totalSpace of the lowest dimension + * @return String format of NDArray + */ + private def buildStringHelper(nd : NDArray, totalSpace : Int) : String = { + var result = "" + val THRESHOLD = 100000 // longest NDArray to show in full + val ARRAYTHRESHOLD = 1000 // longest array to show in full + val shape = nd.shape + val space = totalSpace - shape.length + if (shape.length != 1) { + val (length, postfix) = + if (shape.product > THRESHOLD) { + // reduced NDArray + (1, s"\n${" " * (space + 1)}... with length ${shape(0)}\n") + } else { + (shape(0), "") + } + for (num <- 0 until length) { + val output = buildStringHelper(nd.at(num), totalSpace) + result += s"$output\n" + } + result = s"${" " * space}[\n$result${" " * space}$postfix]" + } else { + if (shape(0) > ARRAYTHRESHOLD) { + // reduced Array + val front = nd.slice(0, 10) + val back = nd.slice(shape(0) - 10, shape(0) - 1) + result = s"${" " * space}[${front.toArray.mkString(",")} ... ${back.toArray.mkString(",")}]" + } else { + result = s"${" " * space}[${nd.toArray.mkString(",")}]" + } + } + result + } + /** * Return a sliced NDArray that shares memory with current one. * NDArray only support continuous slicing on axis 0 From cdca137101c2c0eed6c8b6784a3f70a187c08897 Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 27 Sep 2018 16:06:00 -0700 Subject: [PATCH 02/11] adding Any type input to form NDArray --- .../main/scala/org/apache/mxnet/NDArray.scala | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 38d6445e3d64..d524ac8adcbc 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -200,6 +200,15 @@ object NDArray extends NDArrayBase { "_onehot_encode", Seq(indices, out), Map("out" -> out))(0) } + /** + * Get the String representation of NDArray + * @param nd input NDArray + * @return String + */ + def toString(nd : NDArray) : String = { + nd.visualize + } + /** * Create an empty uninitialized new NDArray, with specified shape. * @@ -509,6 +518,57 @@ object NDArray extends NDArrayBase { array(sourceArr, shape, null) } + /** + * Create a new NDArray based on the structure of source Array + * @param sourceArr Array[Array...Array[Float]...] + * @param ctx context like to pass in + * @return an NDArray with the same shape of the input + */ + def toNDArray(sourceArr: Array[_], ctx : Context = null) : NDArray = { + val shape = ArrayBuffer[Int]() + shapeGetter(sourceArr, shape, 0) + val finalArr = new Array[Float](shape.product) + arrayCombiner(sourceArr, finalArr, 0, finalArr.length - 1) + array(finalArr, Shape(shape), ctx) + } + + private def shapeGetter(sourceArr : Any, + shape : ArrayBuffer[Int], shapeIdx : Int) : Unit = { + sourceArr match { + case arrFloat : Array[Float] => { + val arrLength = arrFloat.length + if (shape.length == shapeIdx) { + shape += arrLength + } + require(shape(shapeIdx) == arrLength, "Each Array should have equal length") + } + case arr : Array[Any] => { + val arrLength = arr.length + if (shape.length == shapeIdx) { + shape += arrLength + } + require(shape(shapeIdx) == arrLength, + s"Each Array should have equal length, expected ${shape(shapeIdx)}, get $arrLength") + arr.foreach(ele => shapeGetter(ele, shape, shapeIdx + 1)) + } + case _ => throw new IllegalArgumentException(s"Wrong type passed: ${sourceArr.getClass}") + } + } + + private def arrayCombiner(sourceArr : Any, arr : Array[Float], start : Int, end : Int) : Unit = { + sourceArr match { + case arrFloat : Array[Float] => { + for (i <- arrFloat.indices) arr(start + i) = arrFloat(i) + } + case arrAny : Array[Any] => { + val fragment = (end - start + 1) / arrAny.length + for (i <- arrAny.indices) + arrayCombiner(arrAny(i), arr, start + i * fragment, end + (i + 1) * fragment) + } + case _ => throw new IllegalArgumentException(s"Wrong type passed: ${sourceArr.getClass}") + } + } + /** * Returns evenly spaced values within a given interval. * Values are generated within the half-open interval [`start`, `stop`). In other From 900e42fcb52415dffdad402b31d7020ff9d84512 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 2 Oct 2018 11:26:26 -0700 Subject: [PATCH 03/11] fix bug and add tests --- .../main/scala/org/apache/mxnet/NDArray.scala | 2 +- .../scala/org/apache/mxnet/NDArraySuite.scala | 26 +++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index d524ac8adcbc..cfe05e0abf37 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -563,7 +563,7 @@ object NDArray extends NDArrayBase { case arrAny : Array[Any] => { val fragment = (end - start + 1) / arrAny.length for (i <- arrAny.indices) - arrayCombiner(arrAny(i), arr, start + i * fragment, end + (i + 1) * fragment) + arrayCombiner(arrAny(i), arr, start + i * fragment, start + (i + 1) * fragment) } case _ => throw new IllegalArgumentException(s"Wrong type passed: ${sourceArr.getClass}") } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala index bc7a0a026bc3..8f084ee40f37 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala @@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.mxnet.NDArrayConversions._ import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} +import scala.collection.mutable.ArrayBuffer class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { private val sequence: AtomicInteger = new AtomicInteger(0) @@ -85,6 +86,31 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { assert(ndarray.toArray === Array(1f, 2f, 3f, 4f)) } + test("create NDArray based on Java Matrix") { + val arrBuf = ArrayBuffer[Array[Float]]() + for (i <- 0 until 100) arrBuf += Array(1.0f, 1.0f, 1.0f, 1.0f) + val arr = Array( + Array( + arrBuf.toArray + ), + Array( + arrBuf.toArray + ) + ) + var nd = NDArray.toNDArray(arr) + require(nd.shape == Shape(2, 1, 100, 4)) + val arr2 = Array(1.0f, 1.0f, 1.0f, 1.0f) + nd = NDArray.toNDArray(arr2) + require(nd.shape == Shape(4)) + } + + test("test Visualize") { + var nd = NDArray.ones(Shape(1, 2, 100, 1)) + nd.visualize + nd = NDArray.ones(Shape(1, 4)) + nd.visualize + } + test("plus") { var ndzeros = NDArray.zeros(2, 1) var ndones = ndzeros + 1f From d4489ff7fa53685f90b716b73f8ea65f6b4cd8f0 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 23 Oct 2018 13:22:49 -0700 Subject: [PATCH 04/11] add a toString method --- .../core/src/main/scala/org/apache/mxnet/NDArray.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index cfe05e0abf37..8ae20ad38cd3 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -821,6 +821,9 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, require(source.length == size, s"array size (${source.length}) do not match the size of NDArray ($size)") checkCall(_LIB.mxFloat64NDArraySyncCopyFromCPU(handle, source, source.length)) + + override def toString() : String = { + s"${this.visualize}" } /** From 088b85d6b5b23503874e04620d2bdd20aa09d390 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 23 Oct 2018 18:10:11 -0700 Subject: [PATCH 05/11] add Visualize Util and migrate visualize structure to there --- .../main/scala/org/apache/mxnet/NDArray.scala | 57 +-------------- .../org/apache/mxnet/util/Visualize.scala | 73 +++++++++++++++++++ 2 files changed, 74 insertions(+), 56 deletions(-) create mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/util/Visualize.scala diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 8ae20ad38cd3..c8857a3c9365 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -200,15 +200,6 @@ object NDArray extends NDArrayBase { "_onehot_encode", Seq(indices, out), Map("out" -> out))(0) } - /** - * Get the String representation of NDArray - * @param nd input NDArray - * @return String - */ - def toString(nd : NDArray) : String = { - nd.visualize - } - /** * Create an empty uninitialized new NDArray, with specified shape. * @@ -823,53 +814,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, checkCall(_LIB.mxFloat64NDArraySyncCopyFromCPU(handle, source, source.length)) override def toString() : String = { - s"${this.visualize}" - } - - /** - * Visualize the internal structure of NDArray - * @return String that show the structure - */ - def visualize: String = { - buildStringHelper(this, this.shape.length) + "\n" - } - /** - * Helper function to create formatted NDArray output - * The NDArray will be represented in a reduced version if too large - * @param nd NDArray as the input - * @param totalSpace totalSpace of the lowest dimension - * @return String format of NDArray - */ - private def buildStringHelper(nd : NDArray, totalSpace : Int) : String = { - var result = "" - val THRESHOLD = 100000 // longest NDArray to show in full - val ARRAYTHRESHOLD = 1000 // longest array to show in full - val shape = nd.shape - val space = totalSpace - shape.length - if (shape.length != 1) { - val (length, postfix) = - if (shape.product > THRESHOLD) { - // reduced NDArray - (1, s"\n${" " * (space + 1)}... with length ${shape(0)}\n") - } else { - (shape(0), "") - } - for (num <- 0 until length) { - val output = buildStringHelper(nd.at(num), totalSpace) - result += s"$output\n" - } - result = s"${" " * space}[\n$result${" " * space}$postfix]" - } else { - if (shape(0) > ARRAYTHRESHOLD) { - // reduced Array - val front = nd.slice(0, 10) - val back = nd.slice(shape(0) - 10, shape(0) - 1) - result = s"${" " * space}[${front.toArray.mkString(",")} ... ${back.toArray.mkString(",")}]" - } else { - result = s"${" " * space}[${nd.toArray.mkString(",")}]" - } - } - result + s"" } /** diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/util/Visualize.scala b/scala-package/core/src/main/scala/org/apache/mxnet/util/Visualize.scala new file mode 100644 index 000000000000..0cafc12a7f8b --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/util/Visualize.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.util + +import org.apache.mxnet.{NDArray, Shape} + +/** + * A visualize helper class to see the internal structure + * of mxnet data-structure + */ +object Visualize { + + /** + * Visualize the internal structure of NDArray + * @return String that show the structure + */ + def toString(nd : NDArray): String = { + buildStringHelper(nd, nd.shape.length) + "\n" + } + /** + * Helper function to create formatted NDArray output + * The NDArray will be represented in a reduced version if too large + * @param nd NDArray as the input + * @param totalSpace totalSpace of the lowest dimension + * @return String format of NDArray + */ + private def buildStringHelper(nd : NDArray, totalSpace : Int) : String = { + var result = "" + val THRESHOLD = 100000 // longest NDArray to show in full + val ARRAYTHRESHOLD = 1000 // longest array to show in full + val shape = nd.shape + val space = totalSpace - shape.length + if (shape.length != 1) { + val (length, postfix) = + if (shape.product > THRESHOLD) { + // reduced NDArray + (1, s"\n${" " * (space + 1)}... with length ${shape(0)}\n") + } else { + (shape(0), "") + } + for (num <- 0 until length) { + val output = buildStringHelper(nd.at(num), totalSpace) + result += s"$output\n" + } + result = s"${" " * space}[\n$result${" " * space}$postfix]" + } else { + if (shape(0) > ARRAYTHRESHOLD) { + // reduced Array + val front = nd.slice(0, 10) + val back = nd.slice(shape(0) - 10, shape(0) - 1) + result = s"${" " * space}[${front.toArray.mkString(",")} ... ${back.toArray.mkString(",")}]" + } else { + result = s"${" " * space}[${nd.toArray.mkString(",")}]" + } + } + result + } +} From 48b19e8cb0f7e8dd57d6d4c750d195f4d1134998 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 23 Oct 2018 18:15:59 -0700 Subject: [PATCH 06/11] update with tests --- .../core/src/test/scala/org/apache/mxnet/NDArraySuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala index 8f084ee40f37..fd0d2021f44f 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala @@ -21,6 +21,7 @@ import java.io.File import java.util.concurrent.atomic.AtomicInteger import org.apache.mxnet.NDArrayConversions._ +import org.apache.mxnet.util.Visualize import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} import scala.collection.mutable.ArrayBuffer @@ -106,9 +107,9 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { test("test Visualize") { var nd = NDArray.ones(Shape(1, 2, 100, 1)) - nd.visualize + Visualize.toString(nd) nd = NDArray.ones(Shape(1, 4)) - nd.visualize + Visualize.toString(nd) } test("plus") { From 7c0a56708b5ed2b218914ce6d9dc0bae075ffbce Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 10 Jan 2019 14:57:29 -0800 Subject: [PATCH 07/11] refactor code --- .../main/scala/org/apache/mxnet/NDArray.scala | 61 ++++++++++++++-- .../org/apache/mxnet/util/Visualize.scala | 73 ------------------- .../scala/org/apache/mxnet/NDArraySuite.scala | 10 ++- 3 files changed, 60 insertions(+), 84 deletions(-) delete mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/util/Visualize.scala diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index c8857a3c9365..fcd125fabfbf 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -28,6 +28,7 @@ import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.language.implicitConversions import scala.ref.WeakReference +import scala.util.Try /** * NDArray Object extends from NDArrayBase for abstract function signatures @@ -718,7 +719,6 @@ object NDArray extends NDArrayBase { genericNDArrayFunctionInvoke("_crop_assign", args, kwargs) } - // TODO: imdecode } /** @@ -745,6 +745,12 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, // we use weak reference to prevent gc blocking private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]] + private val traceProperty = "mxnet.setNDArrayPrintLength" + private lazy val printLength = { + val value = Try(System.getProperty(traceProperty).toInt).getOrElse(1000) + value + } + def serialize(): Array[Byte] = { val buf = ArrayBuffer.empty[Byte] checkCall(_LIB.mxNDArraySaveRawBytes(handle, buf)) @@ -808,13 +814,54 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, checkCall(_LIB.mxNDArraySyncCopyFromCPU(handle, source, source.length)) } - private def syncCopyfrom(source: Array[Double]): Unit = { - require(source.length == size, - s"array size (${source.length}) do not match the size of NDArray ($size)") - checkCall(_LIB.mxFloat64NDArraySyncCopyFromCPU(handle, source, source.length)) + /** + * Visualize the internal structure of NDArray + * @return String that show the structure + */ + override def toString: String = { + val abstractND = buildStringHelper(this, this.shape.length) + val otherInfo = s"" + s"$abstractND\n$otherInfo" + } - override def toString() : String = { - s"" + /** + * Helper function to create formatted NDArray output + * The NDArray will be represented in a reduced version if too large + * @param nd NDArray as the input + * @param totalSpace totalSpace of the lowest dimension + * @return String format of NDArray + */ + private def buildStringHelper(nd : NDArray, totalSpace : Int) : String = { + var result = "" + val THRESHOLD = 10 // longest NDArray[NDArray[...]] to show in full + val ARRAYTHRESHOLD = printLength // longest array to show in full + val shape = nd.shape + val space = totalSpace - shape.length + if (shape.length != 1) { + val (length, postfix) = + if (shape(0) > THRESHOLD) { + // reduced NDArray + (10, s"\n${" " * (space + 1)}... with length ${shape(0)}\n") + } else { + (shape(0), "") + } + for (num <- 0 until length) { + val output = buildStringHelper(nd.at(num), totalSpace) + result += s"$output\n" + } + result = s"${" " * space}[\n$result${" " * space}$postfix${" " * space}]" + } else { + if (shape(0) > ARRAYTHRESHOLD) { + // reduced Array + val front = nd.slice(0, 10) + val back = nd.slice(shape(0) - 10, shape(0) - 1) + result = s"""${" " * space}[${front.toArray.mkString(",")} + | ... ${back.toArray.mkString(",")}]""".stripMargin + } else { + result = s"${" " * space}[${nd.toArray.mkString(",")}]" + } + } + result } /** diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/util/Visualize.scala b/scala-package/core/src/main/scala/org/apache/mxnet/util/Visualize.scala deleted file mode 100644 index 0cafc12a7f8b..000000000000 --- a/scala-package/core/src/main/scala/org/apache/mxnet/util/Visualize.scala +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mxnet.util - -import org.apache.mxnet.{NDArray, Shape} - -/** - * A visualize helper class to see the internal structure - * of mxnet data-structure - */ -object Visualize { - - /** - * Visualize the internal structure of NDArray - * @return String that show the structure - */ - def toString(nd : NDArray): String = { - buildStringHelper(nd, nd.shape.length) + "\n" - } - /** - * Helper function to create formatted NDArray output - * The NDArray will be represented in a reduced version if too large - * @param nd NDArray as the input - * @param totalSpace totalSpace of the lowest dimension - * @return String format of NDArray - */ - private def buildStringHelper(nd : NDArray, totalSpace : Int) : String = { - var result = "" - val THRESHOLD = 100000 // longest NDArray to show in full - val ARRAYTHRESHOLD = 1000 // longest array to show in full - val shape = nd.shape - val space = totalSpace - shape.length - if (shape.length != 1) { - val (length, postfix) = - if (shape.product > THRESHOLD) { - // reduced NDArray - (1, s"\n${" " * (space + 1)}... with length ${shape(0)}\n") - } else { - (shape(0), "") - } - for (num <- 0 until length) { - val output = buildStringHelper(nd.at(num), totalSpace) - result += s"$output\n" - } - result = s"${" " * space}[\n$result${" " * space}$postfix]" - } else { - if (shape(0) > ARRAYTHRESHOLD) { - // reduced Array - val front = nd.slice(0, 10) - val back = nd.slice(shape(0) - 10, shape(0) - 1) - result = s"${" " * space}[${front.toArray.mkString(",")} ... ${back.toArray.mkString(",")}]" - } else { - result = s"${" " * space}[${nd.toArray.mkString(",")}]" - } - } - result - } -} diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala index fd0d2021f44f..0632c25e99b6 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala @@ -21,13 +21,15 @@ import java.io.File import java.util.concurrent.atomic.AtomicInteger import org.apache.mxnet.NDArrayConversions._ -import org.apache.mxnet.util.Visualize import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} +import org.slf4j.LoggerFactory import scala.collection.mutable.ArrayBuffer class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { private val sequence: AtomicInteger = new AtomicInteger(0) + private val logger = LoggerFactory.getLogger(classOf[NDArraySuite]) + test("to java array") { val ndarray = NDArray.zeros(2, 2) assert(ndarray.toArray === Array(0f, 0f, 0f, 0f)) @@ -106,10 +108,10 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { } test("test Visualize") { - var nd = NDArray.ones(Shape(1, 2, 100, 1)) - Visualize.toString(nd) + var nd = NDArray.ones(Shape(1, 2, 1000, 1)) + logger.info(s"Test print large ndarray:\n$nd") nd = NDArray.ones(Shape(1, 4)) - Visualize.toString(nd) + logger.info(s"Test print small ndarray:\n$nd") } test("plus") { From a563d1822e60c00aafecedd3383cd36a3aa1a24a Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 10 Jan 2019 15:29:02 -0800 Subject: [PATCH 08/11] fix the minor issue --- .../core/src/main/scala/org/apache/mxnet/NDArray.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index fcd125fabfbf..6857f92aece6 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -814,6 +814,12 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, checkCall(_LIB.mxNDArraySyncCopyFromCPU(handle, source, source.length)) } + private def syncCopyfrom(source: Array[Double]): Unit = { + require(source.length == size, + s"array size (${source.length}) do not match the size of NDArray ($size)") + checkCall(_LIB.mxFloat64NDArraySyncCopyFromCPU(handle, source, source.length)) + } + /** * Visualize the internal structure of NDArray * @return String that show the structure From a6a3b4f97416f7d82fcb7009eb58be55138fd496 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 15 Jan 2019 18:45:31 -0800 Subject: [PATCH 09/11] add multiple types support --- .../org/apache/mxnet/MX_PRIMITIVES.scala | 6 +++ .../main/scala/org/apache/mxnet/NDArray.scala | 33 ++++++++-------- .../scala/org/apache/mxnet/NDArraySuite.scala | 38 +++++++++++++------ 3 files changed, 51 insertions(+), 26 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala index cb978856963c..2b92d6e4096a 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala @@ -82,4 +82,10 @@ object MX_PRIMITIVES { implicit def MX_DoubleToDouble(d: MX_Double) : Double = d.data + def isValidType(num : Any) : Boolean = { + num match { + case valid @ (_: Float | _: Double) => true + case _ => false + } + } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 6857f92aece6..e9007cf10b3a 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -512,29 +512,34 @@ object NDArray extends NDArrayBase { /** * Create a new NDArray based on the structure of source Array - * @param sourceArr Array[Array...Array[Float]...] + * @param sourceArr Array[Array...Array[MX_PRIMITIVE_TYPE]...] * @param ctx context like to pass in * @return an NDArray with the same shape of the input */ def toNDArray(sourceArr: Array[_], ctx : Context = null) : NDArray = { val shape = ArrayBuffer[Int]() shapeGetter(sourceArr, shape, 0) - val finalArr = new Array[Float](shape.product) - arrayCombiner(sourceArr, finalArr, 0, finalArr.length - 1) - array(finalArr, Shape(shape), ctx) + val container = new Array[Any](shape.product) + arrayCombiner(sourceArr, container, 0, container.length - 1) + val finalArr = container(0) match { + case f: Float => array(container.map(_.asInstanceOf[Float]), Shape(shape), ctx) + case d: Double => array(container.map(_.asInstanceOf[Double]), Shape(shape), ctx) + case _ => throw new IllegalArgumentException(s"Unsupported type ${container(0).getClass}") + } + finalArr } private def shapeGetter(sourceArr : Any, shape : ArrayBuffer[Int], shapeIdx : Int) : Unit = { sourceArr match { - case arrFloat : Array[Float] => { - val arrLength = arrFloat.length + case arr: Array[_] if MX_PRIMITIVES.isValidType(arr(0)) => { + val arrLength = arr.length if (shape.length == shapeIdx) { shape += arrLength } require(shape(shapeIdx) == arrLength, "Each Array should have equal length") } - case arr : Array[Any] => { + case arr: Array[_] => { val arrLength = arr.length if (shape.length == shapeIdx) { shape += arrLength @@ -547,12 +552,13 @@ object NDArray extends NDArrayBase { } } - private def arrayCombiner(sourceArr : Any, arr : Array[Float], start : Int, end : Int) : Unit = { + private def arrayCombiner(sourceArr : Any, arr : Array[Any], + start : Int, end : Int) : Unit = { sourceArr match { - case arrFloat : Array[Float] => { - for (i <- arrFloat.indices) arr(start + i) = arrFloat(i) + case arrValid: Array[_] if MX_PRIMITIVES.isValidType(arrValid(0)) => { + for (i <- arrValid.indices) arr(start + i) = arrValid(i) } - case arrAny : Array[Any] => { + case arrAny: Array[_] => { val fragment = (end - start + 1) / arrAny.length for (i <- arrAny.indices) arrayCombiner(arrAny(i), arr, start + i * fragment, start + (i + 1) * fragment) @@ -746,10 +752,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]] private val traceProperty = "mxnet.setNDArrayPrintLength" - private lazy val printLength = { - val value = Try(System.getProperty(traceProperty).toInt).getOrElse(1000) - value - } + private lazy val printLength = Try(System.getProperty(traceProperty).toInt).getOrElse(1000) def serialize(): Array[Byte] = { val buf = ArrayBuffer.empty[Byte] diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala index 0632c25e99b6..1974bc8b9d89 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala @@ -90,28 +90,44 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { } test("create NDArray based on Java Matrix") { - val arrBuf = ArrayBuffer[Array[Float]]() - for (i <- 0 until 100) arrBuf += Array(1.0f, 1.0f, 1.0f, 1.0f) - val arr = Array( + def arrayGen(num : Any) : Array[Any] = { + val arrayBuf = num match { + case f: Float => + val arr = ArrayBuffer[Array[Float]]() + for (_ <- 0 until 100) arr += Array(1.0f, 1.0f, 1.0f, 1.0f) + arr + case d: Double => + val arr = ArrayBuffer[Array[Double]]() + for (_ <- 0 until 100) arr += Array(1.0d, 1.0d, 1.0d, 1.0d) + arr + case _ => throw new IllegalArgumentException(s"Unsupported Type ${num.getClass}") + } Array( - arrBuf.toArray - ), - Array( - arrBuf.toArray + Array( + arrayBuf.toArray + ), + Array( + arrayBuf.toArray ) - ) - var nd = NDArray.toNDArray(arr) + ) + } + val floatData = 1.0f + var nd = NDArray.toNDArray(arrayGen(floatData)) require(nd.shape == Shape(2, 1, 100, 4)) val arr2 = Array(1.0f, 1.0f, 1.0f, 1.0f) nd = NDArray.toNDArray(arr2) require(nd.shape == Shape(4)) + val doubleData = 1.0d + nd = NDArray.toNDArray(arrayGen(doubleData)) + require(nd.shape == Shape(2, 1, 100, 4)) + require(nd.dtype == DType.Float64) } test("test Visualize") { var nd = NDArray.ones(Shape(1, 2, 1000, 1)) - logger.info(s"Test print large ndarray:\n$nd") + require(nd.toString.split("\n").length == 33) nd = NDArray.ones(Shape(1, 4)) - logger.info(s"Test print small ndarray:\n$nd") + require(nd.toString.split("\n").length == 4) } test("plus") { From 6d532aa907c6bdeb90ff7a13a98191822814ad1b Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 24 Jan 2019 15:30:09 -0800 Subject: [PATCH 10/11] add changes on names and tests --- .../org/apache/mxnet/MX_PRIMITIVES.scala | 2 +- .../main/scala/org/apache/mxnet/NDArray.scala | 22 +++++---- .../scala/org/apache/mxnet/NDArraySuite.scala | 45 ++++++++++++++++++- 3 files changed, 57 insertions(+), 12 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala index 2b92d6e4096a..3a51222cc0b8 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala @@ -82,7 +82,7 @@ object MX_PRIMITIVES { implicit def MX_DoubleToDouble(d: MX_Double) : Double = d.data - def isValidType(num : Any) : Boolean = { + def isValidMxPrimitiveType(num : Any) : Boolean = { num match { case valid @ (_: Float | _: Double) => true case _ => false diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index e9007cf10b3a..6f567cd8216d 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -515,16 +515,18 @@ object NDArray extends NDArrayBase { * @param sourceArr Array[Array...Array[MX_PRIMITIVE_TYPE]...] * @param ctx context like to pass in * @return an NDArray with the same shape of the input + * @throws IllegalArgumentException if the data type is not valid */ def toNDArray(sourceArr: Array[_], ctx : Context = null) : NDArray = { val shape = ArrayBuffer[Int]() shapeGetter(sourceArr, shape, 0) val container = new Array[Any](shape.product) - arrayCombiner(sourceArr, container, 0, container.length - 1) + flattenArray(sourceArr, container, 0, container.length - 1) val finalArr = container(0) match { case f: Float => array(container.map(_.asInstanceOf[Float]), Shape(shape), ctx) case d: Double => array(container.map(_.asInstanceOf[Double]), Shape(shape), ctx) - case _ => throw new IllegalArgumentException(s"Unsupported type ${container(0).getClass}") + case _ => throw new IllegalArgumentException( + s"Unsupported type ${container(0).getClass}, please check MX_PRIMITIVES for valid types") } finalArr } @@ -532,7 +534,7 @@ object NDArray extends NDArrayBase { private def shapeGetter(sourceArr : Any, shape : ArrayBuffer[Int], shapeIdx : Int) : Unit = { sourceArr match { - case arr: Array[_] if MX_PRIMITIVES.isValidType(arr(0)) => { + case arr: Array[_] if MX_PRIMITIVES.isValidMxPrimitiveType(arr(0)) => { val arrLength = arr.length if (shape.length == shapeIdx) { shape += arrLength @@ -552,16 +554,16 @@ object NDArray extends NDArrayBase { } } - private def arrayCombiner(sourceArr : Any, arr : Array[Any], + private def flattenArray(sourceArr : Any, arr : Array[Any], start : Int, end : Int) : Unit = { sourceArr match { - case arrValid: Array[_] if MX_PRIMITIVES.isValidType(arrValid(0)) => { + case arrValid: Array[_] if MX_PRIMITIVES.isValidMxPrimitiveType(arrValid(0)) => { for (i <- arrValid.indices) arr(start + i) = arrValid(i) } case arrAny: Array[_] => { val fragment = (end - start + 1) / arrAny.length for (i <- arrAny.indices) - arrayCombiner(arrAny(i), arr, start + i * fragment, start + (i + 1) * fragment) + flattenArray(arrAny(i), arr, start + i * fragment, start + (i + 1) * fragment) } case _ => throw new IllegalArgumentException(s"Wrong type passed: ${sourceArr.getClass}") } @@ -751,8 +753,10 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, // we use weak reference to prevent gc blocking private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]] - private val traceProperty = "mxnet.setNDArrayPrintLength" - private lazy val printLength = Try(System.getProperty(traceProperty).toInt).getOrElse(1000) + private val lengthProperty = "mxnet.setNDArrayPrintLength" + private val layerProperty = "mxnet.setNDArrayPrintLayerLength" + private lazy val printLength = Try(System.getProperty(lengthProperty).toInt).getOrElse(1000) + private lazy val layerLength = Try(System.getProperty(layerProperty).toInt).getOrElse(10) def serialize(): Array[Byte] = { val buf = ArrayBuffer.empty[Byte] @@ -842,7 +846,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, */ private def buildStringHelper(nd : NDArray, totalSpace : Int) : String = { var result = "" - val THRESHOLD = 10 // longest NDArray[NDArray[...]] to show in full + val THRESHOLD = layerLength // longest NDArray[NDArray[...]] to show in full val ARRAYTHRESHOLD = printLength // longest array to show in full val shape = nd.shape val space = totalSpace - shape.length diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala index 1974bc8b9d89..890e85806759 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala @@ -125,9 +125,50 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { test("test Visualize") { var nd = NDArray.ones(Shape(1, 2, 1000, 1)) - require(nd.toString.split("\n").length == 33) + var data : String = + """ + |[ + | [ + | [ + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | + | ... with length 1000 + | ] + | [ + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | + | ... with length 1000 + | ] + | ] + |] + |""".stripMargin + require(nd.toString.split("\\s+").mkString == data.split("\\s+").mkString) nd = NDArray.ones(Shape(1, 4)) - require(nd.toString.split("\n").length == 4) + data = + """ + |[ + | [1.0,1.0,1.0,1.0] + |] + |""".stripMargin + require(nd.toString.split("\\s+").mkString == data.split("\\s+").mkString) } test("plus") { From 6496ed67e6483fb0e59d39def0ae83c861c96557 Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 24 Jan 2019 17:27:53 -0800 Subject: [PATCH 11/11] make code elegant and improve readability --- .../main/scala/org/apache/mxnet/NDArray.scala | 26 ++++++++----------- .../scala/org/apache/mxnet/NDArraySuite.scala | 14 ++++------ 2 files changed, 16 insertions(+), 24 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 6f567cd8216d..5c345f21faf4 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -518,8 +518,7 @@ object NDArray extends NDArrayBase { * @throws IllegalArgumentException if the data type is not valid */ def toNDArray(sourceArr: Array[_], ctx : Context = null) : NDArray = { - val shape = ArrayBuffer[Int]() - shapeGetter(sourceArr, shape, 0) + val shape = shapeGetter(sourceArr) val container = new Array[Any](shape.product) flattenArray(sourceArr, container, 0, container.length - 1) val finalArr = container(0) match { @@ -531,24 +530,21 @@ object NDArray extends NDArrayBase { finalArr } - private def shapeGetter(sourceArr : Any, - shape : ArrayBuffer[Int], shapeIdx : Int) : Unit = { + private def shapeGetter(sourceArr : Any) : ArrayBuffer[Int] = { sourceArr match { + // e.g : Array[Double] the inner layer case arr: Array[_] if MX_PRIMITIVES.isValidMxPrimitiveType(arr(0)) => { - val arrLength = arr.length - if (shape.length == shapeIdx) { - shape += arrLength - } - require(shape(shapeIdx) == arrLength, "Each Array should have equal length") + ArrayBuffer[Int](arr.length) } + // e.g : Array[Array...[]] case arr: Array[_] => { - val arrLength = arr.length - if (shape.length == shapeIdx) { - shape += arrLength + var arrBuffer = new ArrayBuffer[Int]() + if (!arr.isEmpty) arrBuffer = shapeGetter(arr(0)) + for (idx <- arr.indices) { + require(arrBuffer == shapeGetter(arr(idx))) } - require(shape(shapeIdx) == arrLength, - s"Each Array should have equal length, expected ${shape(shapeIdx)}, get $arrLength") - arr.foreach(ele => shapeGetter(ele, shape, shapeIdx + 1)) + arrBuffer.insert(0, arr.length) + arrBuffer } case _ => throw new IllegalArgumentException(s"Wrong type passed: ${sourceArr.getClass}") } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala index 890e85806759..054300e952a8 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala @@ -91,23 +91,19 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { test("create NDArray based on Java Matrix") { def arrayGen(num : Any) : Array[Any] = { - val arrayBuf = num match { + val array = num match { case f: Float => - val arr = ArrayBuffer[Array[Float]]() - for (_ <- 0 until 100) arr += Array(1.0f, 1.0f, 1.0f, 1.0f) - arr + (for (_ <- 0 until 100) yield Array(1.0f, 1.0f, 1.0f, 1.0f)).toArray case d: Double => - val arr = ArrayBuffer[Array[Double]]() - for (_ <- 0 until 100) arr += Array(1.0d, 1.0d, 1.0d, 1.0d) - arr + (for (_ <- 0 until 100) yield Array(1.0d, 1.0d, 1.0d, 1.0d)).toArray case _ => throw new IllegalArgumentException(s"Unsupported Type ${num.getClass}") } Array( Array( - arrayBuf.toArray + array ), Array( - arrayBuf.toArray + array ) ) }