This repository was archived by the owner on Nov 17, 2023. It is now read-only.
[2.0] Gluon2.0: switch to use forward interface #20262
Merged
85 commits
2fcc643 use forward (barry-jin)
d70041a fix lint (barry-jin)
6a28465 update (barry-jin)
ee41074 update (barry-jin)
74a668b fix lint (barry-jin)
d0d9ff9 remove symbol related test from gluon2.0 blocks (barry-jin)
687e366 fix lint (barry-jin)
7c9ff36 remove legacy tests on gluon (barry-jin)
07f1c26 fix lint (barry-jin)
cf3bb33 add group_norm in npx && fix some tests (barry-jin)
2a8a79d fix lint (barry-jin)
fa3992e gluon2.0 rnn (barry-jin)
a2f4f3b fix lint (barry-jin)
ddad98a update rnn_layer infer_shape (barry-jin)
a0b625e gluon2.0 probability clean up F (barry-jin)
47ba77e update probability (barry-jin)
2f7ef9f fix lint (barry-jin)
2a0b8e9 update (barry-jin)
ca2c1ff use forward interface in tests (barry-jin)
f166621 merge master (barry-jin)
1560a91 fix operator np.average (barry-jin)
c711aed copy contrib.foreach/while_loop/cond contrib.sign_ste/round_ste to np… (barry-jin)
7a07bee fix tests (barry-jin)
3a9914e rnn states set context (barry-jin)
37a92ef fix lint (barry-jin)
0fde664 update block.py && fix tests (barry-jin)
c9168b7 merge upstream (barry-jin)
cfde82c fix lint (barry-jin)
9b17bbf fix (barry-jin)
b2bd01a fix cpu tests (barry-jin)
bdbe4cb update test_subgraph.py (barry-jin)
00ebffe update (barry-jin)
a34c2d4 update tests/python/train/test_autograd.py (barry-jin)
2325873 work around (barry-jin)
0b50090 skip test_slice_pooling2d_slice_pooling2d (barry-jin)
e14f9f7 update docstring (barry-jin)
21c39fc set np default (barry-jin)
850baff reset np for mkl tests (barry-jin)
0b1955b update (barry-jin)
978f97e try to fix possible memory leak (barry-jin)
e207ad3 update tests & turn on doctest (barry-jin)
4835f65 update check_layer_forward_withinput (barry-jin)
3a31387 turn off doctest (barry-jin)
b9b0468 update (barry-jin)
acf2d03 use mx.np.random (barry-jin)
2cbed57 update (barry-jin)
dfd6e20 update check_layer_forward_withinput (barry-jin)
770ab05 update (barry-jin)
d2de786 Revert "update" (barry-jin)
d344e4e Revert "update check_layer_forward_withinput" (barry-jin)
9d3cc66 add dc.clear() to recursively release the references (barry-jin)
2776820 fix lint (barry-jin)
6b373e2 clear cached_graph in destructor (barry-jin)
2667afb reset_np in test_subgraph_op.py (barry-jin)
c400f82 update block destructor (barry-jin)
16b5fe2 Revert "update block destructor" (barry-jin)
06d24e7 Revert "clear cached_graph in destructor" (barry-jin)
1c0723d fix reference leak (barry-jin)
23ca022 Revert "fix reference leak" (barry-jin)
8fb97db Revert "Revert "clear cached_graph in destructor"" (barry-jin)
56ec03a Revert "Revert "update block destructor"" (barry-jin)
7c64283 clear input symbols (barry-jin)
aaa092d update imperative.cc (barry-jin)
44a6a35 update tests (barry-jin)
7aad8de adapt and add back some tests (barry-jin)
97ec57d use np.concatenate (barry-jin)
d0995ca fix lint (barry-jin)
3108f0f update rnn_cell.py (barry-jin)
85eba35 update test_gluon_rnn.py (barry-jin)
4f3a48a clear dc info in GetDeferredComputeSymbol (barry-jin)
cd88301 merge master (barry-jin)
2bd3ba0 Clear deferred compute node entry of output ndarrays (barry-jin)
2b1ad5f dc.clear (barry-jin)
0145e63 update (barry-jin)
a7db1f3 Create DC compatible control flow operators in npx namespace (barry-jin)
2cf07e6 fix lint' (barry-jin)
ec1f788 update (barry-jin)
aa3dd06 update (barry-jin)
f1e8aeb update foreach operator (barry-jin)
132c667 update rnn_cell.py (barry-jin)
0a7d1da add control flow operators in amp lists (barry-jin)
ed04d7e dc.clear after creating graph with dc in control flow operators (barry-jin)
3fc1dbb Merge remote-tracking branch 'upstream/master' into issue-19138 (barry-jin)
4cc1f7e upgrade test_quantization.py to use gluon2.0 (barry-jin)
e9f185b improve documentation (barry-jin)

    @@ -24,8 +24,8 @@
     from time import time

     import mxnet as mx
    -import numpy as np
    -from mxnet import gluon
    +import numpy as onp
    +from mxnet import gluon, np, npx


     _parser = argparse.ArgumentParser(description='Benchmark foreach and while_loop on RNN tasks.')

    @@ -42,8 +42,8 @@ def __init__(self, cell, length, prefix=None, params=None):
             self.length = length
             self.cell = cell

    -    def hybrid_forward(self, F, inputs, states):
    -        out, states = F.contrib.foreach(self.cell, inputs, states)
    +    def forward(self, inputs, states):
    +        out, states = npx.foreach(self.cell, inputs, states)
             return out

    @@ -53,15 +53,15 @@ def __init__(self, cell, length, prefix=None, params=None):
             self.length = length
             self.cell = cell

    -    def hybrid_forward(self, F, inputs, states):
    +    def forward(self, inputs, states):
             def _func(*states):
                 i = states[0]
                 s = states[1: ]
    -            data = inputs.take(i).squeeze(axis=0)
    +            data = np.squeeze(np.take(inputs, i), axis=0)
                 out, new_s = self.cell(data, s)
                 new_s = [i + 1] + new_s
                 return out, new_s
    -        out, states = F.contrib.while_loop(
    +        out, states = npx.while_loop(
                 cond=lambda i, *_: i < self.length,
                 func=_func,
                 loop_vars=states,

    @@ -71,11 +71,11 @@ def _func(*states):


     def _zeros(shape, ctx):
    -    return mx.nd.zeros(shape=shape, ctx=ctx)
    +    return mx.np.zeros(shape=shape, ctx=ctx)


     def _array(shape, ctx):
    -    return mx.nd.normal(loc=0.0, scale=1.0, shape=shape, ctx=ctx)
    +    return mx.np.random.normal(loc=0.0, scale=1.0, size=shape, ctx=ctx)


     def _get_gpus():

    @@ -107,11 +107,11 @@ def run_benchmark(cell_type, ctx, seq_len, batch_size, hidden_dim):
                         res = layer(inputs, states)
                 if is_train:
                     res.backward()
    -            mx.nd.waitall()
    +            mx.npx.waitall()
                 tock = time()
                 times.append((tock - tick) * 1000.0)
             times = times[args.warmup_rounds: ]
    -        print("Time used: mean = %.3f ms, std = %.3f ms" % (np.mean(times), np.std(times)))
    +        print("Time used: mean = %.3f ms, std = %.3f ms" % (onp.mean(times), onp.std(times)))

Contributor:
Will MXNet `np` provide the mean and std functions?

Contributor, Author:
Yes. The mean and std operators are implemented as `mxnet.np.mean` and `mxnet.np.std`.

     def main():
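
The two control-flow hunks above swap `F.contrib.foreach` and `F.contrib.while_loop` for `npx.foreach` and `npx.while_loop` while keeping the same call shape. As a rough illustration of the calling convention those hunks rely on, here is a pure-Python sketch. It is not MXNet's implementation: the helper names `foreach_sketch` and `while_loop_sketch`, the toy `cell`, and the use of plain lists in place of arrays are all assumptions made for this example. The `while_loop` call in the diff is cut off after `loop_vars=states,`, so any further arguments of the real operator are left out here.

```python
def foreach_sketch(body, data, init_states):
    """Call body(step_input, states) once per item of data, threading states through."""
    states = init_states
    outputs = []
    for step_input in data:
        out, states = body(step_input, states)
        outputs.append(out)
    return outputs, states


def while_loop_sketch(cond, func, loop_vars):
    """Call func(*loop_vars) while cond(*loop_vars) holds; func returns (out, new_vars)."""
    outputs = []
    loop_vars = list(loop_vars)
    while cond(*loop_vars):
        out, loop_vars = func(*loop_vars)
        outputs.append(out)
    return outputs, loop_vars


def cell(x, states):
    """Hypothetical stand-in for an RNN cell: new state = old state + input."""
    new_state = states[0] + x
    return new_state, [new_state]


inputs = [1.0, 2.0, 3.0]

# foreach style: the operator walks the sequence itself.
foreach_out, foreach_states = foreach_sketch(cell, inputs, [0.0])


# while_loop style: the step index travels as the first loop variable,
# mirroring `_func` in the diff.
def _func(*loop_vars):
    i, s = loop_vars[0], list(loop_vars[1:])
    out, new_s = cell(inputs[i], s)
    return out, [i + 1] + new_s


while_out, while_states = while_loop_sketch(
    cond=lambda i, *_: i < len(inputs),
    func=_func,
    loop_vars=[0, 0.0],
)
```

Both styles yield the same per-step outputs in this toy setup. The `while_loop` variant additionally carries the step counter in its loop variables, which is why the benchmark's `_func` splits `states` into `i` and `s`.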
Does the 'o' in onp stand for 'original', as in the original NumPy? This will be confusing, since np is well known as the short name for NumPy.
Yes, the 'o' in onp stands for 'official'; it is used to distinguish official NumPy from MXNet NumPy. Usually, users will do `from mxnet import np` and build their models with NumPy operators from MXNet. This provides a NumPy-compatible coding experience in MXNet.
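
In the benchmark hunks shown in this PR, official NumPy (`onp`) is applied only to `times`, a plain Python list of per-iteration durations in milliseconds, after the warm-up rounds are sliced off. A minimal, stdlib-only sketch of that summary step follows. The helper name `summarize_times` and the sample timings are hypothetical. `statistics.pstdev` is used as a stand-in for `onp.std` because NumPy's `std` computes the population standard deviation by default.

```python
import statistics


def summarize_times(times_ms, warmup_rounds):
    """Drop warm-up measurements, then return (mean, population std) in ms."""
    measured = times_ms[warmup_rounds:]
    return statistics.mean(measured), statistics.pstdev(measured)


# Hypothetical per-iteration timings in milliseconds; the first two are warm-up.
times = [50.0, 30.0, 10.0, 12.0, 14.0]
mean_ms, std_ms = summarize_times(times, warmup_rounds=2)
print("Time used: mean = %.3f ms, std = %.3f ms" % (mean_ms, std_ms))
```

Keeping this host-side statistic on plain Python floats is one reason an official-NumPy alias can coexist with `from mxnet import np`: the two names never operate on the same data here.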