Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ package org.apache.mxnet.javaapi

import collection.JavaConverters._

/**
* Constructing a context which is used to specify the device and device type that will
* be utilized by the engine.
*
* @param deviceTypeName {'cpu', 'gpu'} String representing the device type
* @param deviceId The device id of the device, needed for GPU
*/
class Context(val context: org.apache.mxnet.Context) {

val deviceTypeid: Int = context.deviceTypeid
Expand All @@ -26,6 +33,11 @@ class Context(val context: org.apache.mxnet.Context) {
= this(new org.apache.mxnet.Context(deviceTypeName, deviceId))

def withScope[T](body: => T): T = context.withScope(body)

/**
* Return device type of current context.
* @return device_type
*/
def deviceType: String = context.deviceType

override def toString: String = context.toString
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,13 @@ object DataDesc{

implicit def toDataDesc(dataDesc: DataDesc): org.apache.mxnet.DataDesc = dataDesc.dataDesc

/**
* Get the dimension that corresponds to the batch size.
* @param layout layout string. For example, "NCHW".
* @return An axis indicating the batch_size dimension. When data-parallelism is used,
* the data will be automatically split and concatenate along the batch_size dimension.
* Axis can be -1, which means the whole array will be copied
* for each data-parallelism device.
*/
def getBatchAxis(layout: String): Int = org.apache.mxnet.DataDesc.getBatchAxis(Some(layout))
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,64 @@ object NDArray {

def waitall(): Unit = org.apache.mxnet.NDArray.waitall()

/**
* One hot encoding indices into matrix out.
* @param indices An NDArray containing indices of the categorical features.
* @param out The result holder of the encoding.
* @return Same as out.
*/
def onehotEncode(indices: NDArray, out: NDArray): NDArray
= org.apache.mxnet.NDArray.onehotEncode(indices, out)

/**
* Create an empty uninitialized new NDArray, with specified shape.
*
* @param shape shape of the NDArray.
* @param ctx The context of the NDArray.
*
* @return The created NDArray.
*/
def empty(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
= org.apache.mxnet.NDArray.empty(shape, ctx, dtype)
def empty(ctx: Context, shape: Array[Int]): NDArray
= org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)
def empty(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
= org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)

/**
* Create a new NDArray filled with 0, with specified shape.
*
* @param shape shape of the NDArray.
* @param ctx The context of the NDArray.
*
* @return The created NDArray.
*/
def zeros(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
= org.apache.mxnet.NDArray.zeros(shape, ctx, dtype)
def zeros(ctx: Context, shape: Array[Int]): NDArray
= org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)
def zeros(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
= org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)

/**
* Create a new NDArray filled with 1, with specified shape.
* @param shape shape of the NDArray.
* @param ctx The context of the NDArray.
* @return The created NDArray.
*/
def ones(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
= org.apache.mxnet.NDArray.ones(shape, ctx, dtype)
def ones(ctx: Context, shape: Array[Int]): NDArray
= org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)
def ones(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
= org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)

/**
* Create a new NDArray filled with given value, with specified shape.
* @param shape shape of the NDArray.
* @param value value to be filled with
* @param ctx The context of the NDArray
*/
def full(shape: Shape, value: Float, ctx: Context): NDArray
= org.apache.mxnet.NDArray.full(shape, value, ctx)

Expand All @@ -65,37 +102,102 @@ object NDArray {
def minimum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
def minimum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)


/**
* Returns the result of element-wise **equal to** (==) comparison operation with broadcasting.
* For each element in input arrays, return 1(true) if corresponding elements are same,
* otherwise return 0(false).
*/
def equal(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
def equal(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)

/**
* Returns the result of element-wise **not equal to** (!=) comparison operation
* with broadcasting.
* For each element in input arrays, return 1(true) if corresponding elements are different,
* otherwise return 0(false).
*/
def notEqual(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
def notEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)

/**
* Returns the result of element-wise **greater than** (>) comparison operation
* with broadcasting.
* For each element in input arrays, return 1(true) if lhs elements are greater than rhs,
* otherwise return 0(false).
*/
def greater(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
def greater(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)

/**
* Returns the result of element-wise **greater than or equal to** (>=) comparison
* operation with broadcasting.
* For each element in input arrays, return 1(true) if lhs elements are greater than equal to rhs
* otherwise return 0(false).
*/
def greaterEqual(lhs: NDArray, rhs: NDArray): NDArray
= org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
def greaterEqual(lhs: NDArray, rhs: Float): NDArray
= org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)

/**
* Returns the result of element-wise **lesser than** (<) comparison operation
* with broadcasting.
* For each element in input arrays, return 1(true) if lhs elements are less than rhs,
* otherwise return 0(false).
*/
def lesser(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
def lesser(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)

/**
* Returns the result of element-wise **lesser than or equal to** (<=) comparison
* operation with broadcasting.
* For each element in input arrays, return 1(true) if lhs elements are
* lesser than equal to rhs, otherwise return 0(false).
*/
def lesserEqual(lhs: NDArray, rhs: NDArray): NDArray
= org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
def lesserEqual(lhs: NDArray, rhs: Float): NDArray
= org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)

/**
* Create a new NDArray that copies content from source_array.
* @param sourceArr Source data to create NDArray from.
* @param shape shape of the NDArray
* @param ctx The context of the NDArray, default to current default context.
* @return The created NDArray.
*/
def array(sourceArr: java.util.List[java.lang.Float], shape: Shape, ctx: Context = null): NDArray
= org.apache.mxnet.NDArray.array(
sourceArr.asScala.map(ele => Float.unbox(ele)).toArray, shape, ctx)

/**
* Returns evenly spaced values within a given interval.
* Values are generated within the half-open interval [`start`, `stop`). In other
* words, the interval includes `start` but excludes `stop`.
* @param start Start of interval.
* @param stop End of interval.
* @param step Spacing between values.
* @param repeat Number of times to repeat each element.
* @param ctx Device context.
* @param dType The data type of the `NDArray`.
* @return NDArray of evenly spaced values in the specified range.
*/
def arange(start: Float, stop: Float, step: Float, repeat: Int,
ctx: Context, dType: DType.DType): NDArray =
org.apache.mxnet.NDArray.arange(start, Some(stop), step, repeat, ctx, dType)
}

/**
* NDArray object in mxnet.
* NDArray is basic ndarray/Tensor like data structure in mxnet. <br />
* <b>
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or use it inside a try-with-resources() or user [[ResourceScope]] in a try-with-resources block

* NOTE: NDArray is stored in native memory. Use NDArray in a try-with-resources() construct
* or a [[ResourceScope]] in a try-with-resource to have them automatically disposed. You can
* explicitly control the lifetime of NDArray by calling dispose manually. Failure to do this
* will result in leaking native memory.
* </b>
*/
class NDArray(val nd : org.apache.mxnet.NDArray ) {

def this(arr : Array[Float], shape : Shape, ctx : Context) = {
Expand All @@ -108,28 +210,88 @@ class NDArray(val nd : org.apache.mxnet.NDArray ) {

def serialize() : Array[Byte] = nd.serialize()

/**
* Release the native memory. <br />
* The NDArrays it depends on will NOT be disposed. <br />
* The object shall never be used after it is disposed.
*/
def dispose() : Unit = nd.dispose()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am contemplating if we ever need to expose this at all or should we direct users to use it along with ResourceScope to make the user experience better


/**
* Dispose all NDArrays who help to construct this array. <br />
* e.g. (a * b + c).disposeDeps() will dispose a, b, c (including their deps) and a * b
* @return this array
*/
def disposeDeps() : NDArray = nd.disposeDepsExcept()
// def disposeDepsExcept(arr : Array[NDArray]) : NDArray = nd.disposeDepsExcept()

/**
* Return a sliced NDArray that shares memory with current one.
* NDArray only support continuous slicing on axis 0
*
* @param start Starting index of slice.
* @param stop Finishing index of slice.
*
* @return a sliced NDArray that shares memory with current one.
*/
def slice(start : Int, stop : Int) : NDArray = nd.slice(start, stop)

/**
* Return a sliced NDArray at the ith position of axis0
* @param i
* @return a sliced NDArray that shares memory with current one.
*/
def slice (i : Int) : NDArray = nd.slice(i)

/**
* Return a sub NDArray that shares memory with current one.
* the first axis will be rolled up, which causes its shape different from slice(i, i+1)
* @param idx index of sub array.
*/
def at(idx : Int) : NDArray = nd.at(idx)

def T : NDArray = nd.T

/**
* Get data type of current NDArray.
* @return class representing type of current ndarray
*/
def dtype : DType = nd.dtype

/**
* Return a copied numpy array of current array with specified type.
* @param dtype Desired type of result array.
* @return A copy of array content.
*/
def asType(dtype : DType) : NDArray = nd.asType(dtype)

/**
* Return a reshaped NDArray that shares memory with current one.
* @param dims New shape.
*
* @return a reshaped NDArray that shares memory with current one.
*/
def reshape(dims : Array[Int]) : NDArray = nd.reshape(dims)

/**
* Block until all pending writes operations on current NDArray are finished.
* This function will return when all the pending writes to the current
* NDArray finishes. There can still be pending read going on when the
* function returns.
*/
def waitToRead(): Unit = nd.waitToRead()

/**
* Get context of current NDArray.
* @return The context of current NDArray.
*/
def context : Context = nd.context

/**
* Set the values of the NDArray
* @param value Value to set
* @return Current NDArray
*/
def set(value : Float) : NDArray = nd.set(value)
def set(other : NDArray) : NDArray = nd.set(other)
def set(other : Array[Float]) : NDArray = nd.set(other)
Expand Down Expand Up @@ -167,20 +329,57 @@ class NDArray(val nd : org.apache.mxnet.NDArray ) {
def lesserEqual(other : NDArray) : NDArray = this.nd <= other
def lesserEqual(other : Float) : NDArray = this.nd <= other

/**
* Return a copied flat java array of current array (row-major).
* @return A copy of array content.
*/
def toArray : Array[Float] = nd.toArray

/**
* Return a CPU scalar(float) of current ndarray.
* This ndarray must have shape (1,)
*
* @return The scalar representation of the ndarray.
*/
def toScalar : Float = nd.toScalar

/**
* Copy the content of current array to other.
*
* @param other Target NDArray or context we want to copy data to.
* @return The copy target NDArray
*/
def copyTo(other : NDArray) : NDArray = nd.copyTo(other)

/**
* Copy the content of current array to a new NDArray in the context.
*
* @param ctx Target context we want to copy data to.
* @return The copy target NDArray
*/
def copyTo(ctx : Context) : NDArray = nd.copyTo(ctx)

/**
* Clone the current array
* @return the copied NDArray in the same context
*/
def copy() : NDArray = copyTo(this.context)

/**
* Get shape of current NDArray.
* @return an array representing shape of current ndarray
*/
def shape : Shape = nd.shape


def size : Int = shape.product

/**
* Return an `NDArray` that lives in the target context. If the array
* is already in that context, `self` is returned. Otherwise, a copy is made.
* @param context The target context we want the return value to live in.
* @return A copy or `self` as an `NDArray` that lives in the target context.
*/
def asInContext(context: Context): NDArray = nd.asInContext(context)

override def equals(obj: Any): Boolean = nd.equals(obj)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,21 @@ import org.apache.mxnet.javaapi.{Context, DataDesc, NDArray}
import scala.collection.JavaConverters
import scala.collection.JavaConverters._


/**
* The ObjectDetector class helps to run ObjectDetection tasks where the goal
* is to find bounding boxes and corresponding labels for objects in a image.
*
* @param modelPathPrefix Path prefix from where to load the model artifacts.
* These include the symbol, parameters, and synset.txt.
* Example: file://model-dir/ssd_resnet50_512 (containing
* ssd_resnet50_512-symbol.json, ssd_resnet50_512-0000.params,
* and synset.txt)
* @param inputDescriptors Descriptors defining the input node names, shape,
* layout and type parameters
* @param contexts Device contexts on which you want to run inference.
* Defaults to CPU.
* @param epoch Model epoch to load; defaults to 0
*/
class ObjectDetector(val objDetector: org.apache.mxnet.infer.ObjectDetector){

def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc], contexts:
Expand Down
Loading