Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-913] Java API --- Scala NDArray Improvement#12536

Closed
lanking520 wants to merge 9 commits intoapache:masterfrom
lanking520:java-ndarray
Closed

[MXNET-913] Java API --- Scala NDArray Improvement#12536
lanking520 wants to merge 9 commits intoapache:masterfrom
lanking520:java-ndarray

Conversation

@lanking520
Copy link
Copy Markdown
Member

@lanking520 lanking520 commented Sep 12, 2018

Description

This PR contains some addition in NDArray as well as Java compatible functionalities for Scala package.

ArgBuilder: Allows user to pass in a single variable or a batch variables to construct a Scala Sequence or Scala Map.

NDArray:

  • Brings some improvement on NDArray such as a toString method that can allow user to see a formatted output.
  • Add new NDArray() constructor allowing Java users to create a new one through this way.
  • Operator improvement: Adding add, subtract methods to allow java user use them

Unit test is coming soon...

@nswamy @yzhliu @andrewfayres

Demo code

import org.apache.mxnet.*;
import org.apache.mxnet.api.java.*;


public class playjava {

    static final NDArray$ NDArray = NDArray$.MODULE$;
    static final Java2ScalaConversionKit$ J2S = Java2ScalaConversionKit$.MODULE$;

    public static void main(String[] args) {
                NDArray nd = new NDArray(new float[]{1.0f, 2.0f, 3.0f, 4.0f}, new Shape(new int[]{1, 4}), Context.cpu(0));
        System.out.println(nd.visualize());


        NDArray result = NDArray.norm(new ArgBuilder().addArg(nd).addArg(2).addArg(1).buildSeq()).get();
        System.out.println(result.visualize());

        result = NDArray.norm(new ArgBuilder().addArg("ord", 2).addArg("axis", 1).buildMap(),
                new ArgBuilder().addArg(nd).buildSeq()).get();
        System.out.println(result.visualize());

        result = NDArray.norm(new ArgBuilder().addBatchArgs(new Object[]{nd, 2, 1}).buildSeq()).get();
        System.out.println(result.visualize());

        NDArray.array(new float[]{1f, 2f, 3f}, new Shape(new int[]{1, 3}), null);

        NDArray nd2 = NDArray.zeros(new int[]{10000, 10000});
        System.out.println(nd2.visualize());
    }
}

Output:

[
 [1.0,2.0,3.0,4.0]
]

[5.477226]

[5.477226]

[5.477226]

[
 [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0 ... 0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]

 ... with length 10000
]

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain the what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

@kalyc
Copy link
Copy Markdown
Contributor

kalyc commented Sep 14, 2018

Thanks for your contribution @lanking520
@mxnet-label-bot[pr-work-in-progress]

@marcoabreu marcoabreu added the pr-work-in-progress PR is still work in progress label Sep 14, 2018
@gigasquid
Copy link
Copy Markdown
Member

@lanking520 I really like to toString changes to improve readability but I wonder about the behavior as the matrices get large. Do we want to print out 1000x1000 arrays by default?

)
val output = List(
("org.apache.mxnet.Symbol", true),
("Int", false),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wouldn't this break backward compatibility?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is just for Macros generation. As long as integration test passed, it didn't break the BC.

* This arg Builder is intent to solve Java to Scala conversion
* to take the input such as (arg: Any*)
*/
class ArgBuilder {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this? why are you going back in time? we want to move away from (arg: Any*) to type-safe APIs

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think type-safe is not as important as ease of usage. For type-safe in Java, the major question is how to deal with default args.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess Java does not get defaults. @andrewfayres rightly asked how many APIs do we have that have a large number of parameters to matter? may be we just add builder to those if they are only a few and for others they can pass Scala Option.None or change the APIs to accept gauva Optional.
@lanking520 can you find how many have more than 5 parameters?

Copy link
Copy Markdown
Member Author

@lanking520 lanking520 Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After running this:

    printf(s"\n\n\n\nTotal numbers " +
      ndarrayFunctions.count(_.listOfArgs.length > 5)
      + " out of " + ndarrayFunctions.length + "\n\n\n\n"
    )

get output

Total numbers 64 out of 665

However, we should consider out as a param in Type-safe API param. Counting this we got

Total numbers 101 out of 665

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not the total number of parameters that matters, it's how many default parameters there are in the method. If a method has 5 parameters and no defaults then the builder doesn't help any. If there are 5 parameters and all are defaults then it helps a lot.

@lanking520 Can we get a count of how many methods have more than 3 default args? If possible what I'd really like is a distribution (x methods have 1 default arg, y have 2, ...) but if this is too difficult I understand.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

printf(s"\n\n\n\nTotal numbers " +
      ndarrayFunctions.count(_.listOfArgs.count(_.isOptional) > 3)
      + " out of " + ndarrayFunctions.length + "\n\n\n\n"
    )

Here you go

Total numbers 65 out of 665

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So about 10% of NDArray methods have more than 3 default args.

I'm going to give some more thought before I commit to this but my initial reaction is that we include this but try not to promote it's use too much in docs/examples. Leave it there as an ease of use option for the customer with the understanding that when they use this they will be giving up type-safety.

@lanking520
Copy link
Copy Markdown
Member Author

lanking520 commented Sep 17, 2018

@gigasquid in order to solve this issue. I have done two things:

  1. Restricted toArray method with maximum length to 100000
  2. Rewrite the toString method to optimize memory space as well as take a small portion of large output if the output array is too large

@nswamy
Copy link
Copy Markdown
Member

nswamy commented Sep 18, 2018

@lanking520 may be just printing the NDArray metadata in toString might be suffient
(x.shape, x.size, x.dtype, x.context)

  • What is the use-case you are trying to solve with toString() ?
  • I don't think you should threshold on maxLength in toArray(), it will be confusing. there might be genuine cases where you want to move from Off-Heap(NDArray) to JVM Heap(Array[Float])

@lanking520
Copy link
Copy Markdown
Member Author

@lanking520 may be just printing the NDArray metadata in toString might be suffient
(x.shape, x.size, x.dtype, x.context)

  • What is the use-case you are trying to solve with toString() ?
  • I don't think you should threshold on maxLength in toArray(), it will be confusing. there might be genuine cases where you want to move from Off-Heap(NDArray) to JVM Heap(Array[Float])

Hi @nswamy , the point for toString method is to present something similar to what Python ndarray have. Usually when you print a python ndarray, it will give you an overview of the datastructure.

The reason I set a limitation there is because I am facing this issue while I am trying to get a toArray result from a 1000000 element NDArray. It reported the error that JVM memory used up when we do this kind of operation. In that case, user should think of reducing the dimension, slicing the array or even write them into files.

* @return A copy of array content.
*/
def toArray: Array[Float] = {
require(shape.toArray.product < 1000000,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we shouldn't add arbitrary limitations.

}

override def toString: String = {
buildStringHelper(this, this.shape.length) + "\n"
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

print NDArray's metadata instead
(x.shape, x.size, x.dtype, x.context, x.handle(%x)->hex value of native mem ref)

@lanking520 lanking520 changed the title [MXNET-913][WIP] Java API --- Scala NDArray Improvement [MXNET-913] Java API --- Scala NDArray Improvement Sep 20, 2018
@vandanavk
Copy link
Copy Markdown
Contributor

@lanking520 Could you check the build failure

@vrakesh
Copy link
Copy Markdown
Contributor

vrakesh commented Oct 9, 2018

@lanking520 Requesting an update on the PR, is the build failure issue resolved?

@lanking520
Copy link
Copy Markdown
Member Author

Stop the experiment and ship with this solution #12772

@lanking520 lanking520 closed this Oct 9, 2018
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.

Labels

pr-work-in-progress PR is still work in progress

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants