diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 9b6a7dc66540..dc1273315b24 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -248,6 +248,11 @@ object NDArray extends NDArrayBase { def ones(ctx: Context, shape: Int *): NDArray = ones(Shape(shape: _*), ctx) + // Java compatible conversion methods + def empty(shape: Array[Int]): NDArray = empty(Shape(shape)) + def zeros(shape: Array[Int]): NDArray = zeros(Shape(shape)) + def ones(shape: Array[Int]) : NDArray = ones(Shape(shape)) + /** * Create a new NDArray filled with given value, with specified shape. * @param shape shape of the NDArray. @@ -567,6 +572,18 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, NDArrayCollector.collect(this) } + /** + * Java Flavor creating new NDArray + * @param arr + * @param shape + * @param ctx + * @return + */ + def this(arr : Array[Float], shape : Shape, ctx : Context) = { + this(NDArray.newAllocHandle(shape, ctx, delayAlloc = false, Base.MX_REAL_TYPE)) + this.set(arr) + } + // record arrays who construct this array instance // we use weak reference to prevent gc blocking private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]] @@ -940,6 +957,44 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, this } + /* Java Compatibility Functions + Function name with underscore means + it is going to do the operator as well as + update itself such as += + */ + def add(other : NDArray) : NDArray = this + other + def add(other : Float) : NDArray = this + other + def _add(other : NDArray) : NDArray = this += other + def _add(other : Float) : NDArray = this += other + def subtract(other : NDArray) : NDArray = this - other + def subtract(other : Float) : NDArray = this - other + def _subtract(other : NDArray) : NDArray = this -= other + def _subtract(other : Float) : NDArray = this -= other + def multiply(other : NDArray) : NDArray = this * other + def multiply(other : Float) : NDArray = this * other + def _multiply(other : NDArray) : NDArray = this *= other + def _multiply(other : Float) : NDArray = this *= other + def div(other : NDArray) : NDArray = this / other + def div(other : Float) : NDArray = this / other + def _div(other : NDArray) : NDArray = this /= other + def _div(other : Float) : NDArray = this /= other + def pow(other : NDArray) : NDArray = this ** other + def pow(other : Float) : NDArray = this ** other + def _pow(other : NDArray) : NDArray = this **= other + def _pow(other : Float) : NDArray = this **= other + def mod(other : NDArray) : NDArray = this % other + def mod(other : Float) : NDArray = this % other + def _mod(other : NDArray) : NDArray = this %= other + def _mod(other : Float) : NDArray = this %= other + def greater(other : NDArray) : NDArray = this > other + def greater(other : Float) : NDArray = this > other + def greaterEqual(other : NDArray) : NDArray = this >= other + def greaterEqual(other : Float) : NDArray = this >= other + def lesser(other : NDArray) : NDArray = this < other + def lesser(other : Float) : NDArray = this < other + def lesserEqual(other : NDArray) : NDArray = this <= other + def lesserEqual(other : Float) : NDArray = this <= other + /** * Return a copied flat java array of current array (row-major). * @return A copy of array content. diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala index 689176217722..772ec8e262c1 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala @@ -29,6 +29,15 @@ class Shape(dims: Traversable[Int]) extends Serializable { this(dims.toVector) } + /** + * Java compatible constructor + * @param dims Array of Int input + * @return Shape + */ + def this(dims: Array[Int]) = { + this(dims.toVector) + } + def apply(dim: Int): Int = shape(dim) def get(dim: Int): Int = apply(dim) def size: Int = shape.size diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/api/java/ArgBuilder.scala b/scala-package/core/src/main/scala/org/apache/mxnet/api/java/ArgBuilder.scala new file mode 100644 index 000000000000..4b478e5fc5dd --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/api/java/ArgBuilder.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.api.java + +import scala.collection.mutable +import scala.collection.mutable.ListBuffer +import collection.JavaConverters._ + +/** + * This arg Builder is intent to solve Java to Scala conversion + * to take the input such as (arg: Any*) + */ +class ArgBuilder { + private var data = ListBuffer[Any]() + private var map = mutable.Map[String, Any]() + + def addArg(anyRef: AnyRef): ArgBuilder = { + require(map.isEmpty, + "Map is not empty, please do either key-value or positional-arg but not both") + this.data += anyRef.asInstanceOf[Any] + this + } + + def addArg(key : String, value : AnyRef) : ArgBuilder = { + require(data.isEmpty, + "Data is not empty, please do either key-value or positional-arg but not both") + this.map(key) = value.asInstanceOf[Any] + this + } + + def addBatchArgs(list : java.util.List[AnyRef]) : ArgBuilder = { + require(map.isEmpty, + "Map is not empty, please do either key-value or positional-arg but not both") + for (i <- 0 to list.size()) { + this.data += list.get(i) + } + this + } + + def addBatchArgs(arr : Array[AnyRef]) : ArgBuilder = { + require(map.isEmpty, + "Map is not empty, please do either key-value or positional-arg but not both") + arr.foreach(ele => this.data += ele) + this + } + + def buildMap() : Map[String, Any] = { + this.map.toMap + } + + def buildSeq() : Seq[Any] = { + this.data + } +} diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala index 5d88bb39e502..edbc98e5a11a 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala @@ -38,6 +38,11 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { assert(ndones.toScalar === 1f) } + test("new NDArray") { + val ndarray = new NDArray(Array(1.0f, 2.0f), Shape(1, 2), Context.cpu()) + assert(ndarray.shape == Shape(1, 2)) + } + test ("call toScalar on an ndarray which is not a scalar") { intercept[Exception] { NDArray.zeros(1, 1).toScalar } }