Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,11 @@ object NDArray extends NDArrayBase {

def ones(ctx: Context, shape: Int *): NDArray = ones(Shape(shape: _*), ctx)

// Java compatible conversion methods
def empty(shape: Array[Int]): NDArray = empty(Shape(shape))
def zeros(shape: Array[Int]): NDArray = zeros(Shape(shape))
def ones(shape: Array[Int]) : NDArray = ones(Shape(shape))

/**
* Create a new NDArray filled with given value, with specified shape.
* @param shape shape of the NDArray.
Expand Down Expand Up @@ -567,6 +572,18 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArrayCollector.collect(this)
}

/**
* Java Flavor creating new NDArray
* @param arr
* @param shape
* @param ctx
* @return
*/
def this(arr : Array[Float], shape : Shape, ctx : Context) = {
this(NDArray.newAllocHandle(shape, ctx, delayAlloc = false, Base.MX_REAL_TYPE))
this.set(arr)
}

// record arrays who construct this array instance
// we use weak reference to prevent gc blocking
private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]]
Expand Down Expand Up @@ -940,6 +957,44 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
this
}

/* Java Compatibility Functions
Function name with underscore means
it is going to do the operator as well as
update itself such as +=
*/
def add(other : NDArray) : NDArray = this + other
def add(other : Float) : NDArray = this + other
def _add(other : NDArray) : NDArray = this += other
def _add(other : Float) : NDArray = this += other
def subtract(other : NDArray) : NDArray = this - other
def subtract(other : Float) : NDArray = this - other
def _subtract(other : NDArray) : NDArray = this -= other
def _subtract(other : Float) : NDArray = this -= other
def multiply(other : NDArray) : NDArray = this * other
def multiply(other : Float) : NDArray = this * other
def _multiply(other : NDArray) : NDArray = this *= other
def _multiply(other : Float) : NDArray = this *= other
def div(other : NDArray) : NDArray = this / other
def div(other : Float) : NDArray = this / other
def _div(other : NDArray) : NDArray = this /= other
def _div(other : Float) : NDArray = this /= other
def pow(other : NDArray) : NDArray = this ** other
def pow(other : Float) : NDArray = this ** other
def _pow(other : NDArray) : NDArray = this **= other
def _pow(other : Float) : NDArray = this **= other
def mod(other : NDArray) : NDArray = this % other
def mod(other : Float) : NDArray = this % other
def _mod(other : NDArray) : NDArray = this %= other
def _mod(other : Float) : NDArray = this %= other
def greater(other : NDArray) : NDArray = this > other
def greater(other : Float) : NDArray = this > other
def greaterEqual(other : NDArray) : NDArray = this >= other
def greaterEqual(other : Float) : NDArray = this >= other
def lesser(other : NDArray) : NDArray = this < other
def lesser(other : Float) : NDArray = this < other
def lesserEqual(other : NDArray) : NDArray = this <= other
def lesserEqual(other : Float) : NDArray = this <= other

/**
* Return a copied flat java array of current array (row-major).
* @return A copy of array content.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ class Shape(dims: Traversable[Int]) extends Serializable {
this(dims.toVector)
}

/**
* Java compatible constructor
* @param dims Array of Int input
* @return Shape
*/
def this(dims: Array[Int]) = {
this(dims.toVector)
}

def apply(dim: Int): Int = shape(dim)
def get(dim: Int): Int = apply(dim)
def size: Int = shape.size
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet.api.java

import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import collection.JavaConverters._

/**
* This arg Builder is intent to solve Java to Scala conversion
* to take the input such as (arg: Any*)
*/
class ArgBuilder {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this? why are you going back in time? we want to move away from (arg: Any*) to type-safe APIs

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think type-safe is not as important as ease of usage. For type-safe in Java, the major question is how to deal with default args.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess Java does not get defaults. @andrewfayres rightly asked how many APIs do we have that have a large number of parameters to matter? may be we just add builder to those if they are only a few and for others they can pass Scala Option.None or change the APIs to accept gauva Optional.
@lanking520 can you find how many have more than 5 parameters?

Copy link
Copy Markdown
Member Author

@lanking520 lanking520 Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After running this:

    printf(s"\n\n\n\nTotal numbers " +
      ndarrayFunctions.count(_.listOfArgs.length > 5)
      + " out of " + ndarrayFunctions.length + "\n\n\n\n"
    )

get output

Total numbers 64 out of 665

However, we should consider out as a param in Type-safe API param. Counting this we got

Total numbers 101 out of 665

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not the total number of parameters that matters, it's how many default parameters there are in the method. If a method has 5 parameters and no defaults then the builder doesn't help any. If there are 5 parameters and all are defaults then it helps a lot.

@lanking520 Can we get a count of how many methods have more than 3 default args? If possible what I'd really like is a distribution (x methods have 1 default arg, y have 2, ...) but if this is too difficult I understand.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

printf(s"\n\n\n\nTotal numbers " +
      ndarrayFunctions.count(_.listOfArgs.count(_.isOptional) > 3)
      + " out of " + ndarrayFunctions.length + "\n\n\n\n"
    )

Here you go

Total numbers 65 out of 665

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So about 10% of NDArray methods have more than 3 default args.

I'm going to give some more thought before I commit to this but my initial reaction is that we include this but try not to promote it's use too much in docs/examples. Leave it there as an ease of use option for the customer with the understanding that when they use this they will be giving up type-safety.

private var data = ListBuffer[Any]()
private var map = mutable.Map[String, Any]()

def addArg(anyRef: AnyRef): ArgBuilder = {
require(map.isEmpty,
"Map is not empty, please do either key-value or positional-arg but not both")
this.data += anyRef.asInstanceOf[Any]
this
}

def addArg(key : String, value : AnyRef) : ArgBuilder = {
require(data.isEmpty,
"Data is not empty, please do either key-value or positional-arg but not both")
this.map(key) = value.asInstanceOf[Any]
this
}

def addBatchArgs(list : java.util.List[AnyRef]) : ArgBuilder = {
require(map.isEmpty,
"Map is not empty, please do either key-value or positional-arg but not both")
for (i <- 0 to list.size()) {
this.data += list.get(i)
}
this
}

def addBatchArgs(arr : Array[AnyRef]) : ArgBuilder = {
require(map.isEmpty,
"Map is not empty, please do either key-value or positional-arg but not both")
arr.foreach(ele => this.data += ele)
this
}

def buildMap() : Map[String, Any] = {
this.map.toMap
}

def buildSeq() : Seq[Any] = {
this.data
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
assert(ndones.toScalar === 1f)
}

test("new NDArray") {
val ndarray = new NDArray(Array(1.0f, 2.0f), Shape(1, 2), Context.cpu())
assert(ndarray.shape == Shape(1, 2))
}

test ("call toScalar on an ndarray which is not a scalar") {
intercept[Exception] { NDArray.zeros(1, 1).toScalar }
}
Expand Down