ONNX export: Square and ReduceSum operators#12653
ONNX export: Square and ReduceSum operators#12653vandanavk wants to merge 1 commit intoapache:masterfrom
Conversation
|
@mxnet-label-bot[pr-awaiting-review] |
f228e06 to
3d9372d
Compare
anirudhacharya
left a comment
There was a problem hiding this comment.
this should wait on #12633 and please add tests for these operators.
|
Rebased on top of #12633. @anirudhacharya @Roshrini @zhreshold This PR is ready for review again. |
| mx_axis = attrs.get("axis", None) | ||
| axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None | ||
|
|
||
| keepdims = 1 if ("keepdims" in attrs) and \ |
There was a problem hiding this comment.
do you need the first condition ("keepdims" in attrs) here?
There was a problem hiding this comment.
it's an optional parameter. so added this check just in case.
There was a problem hiding this comment.
no, my point was since you are doing attrs.get("keepdims") the first condition is a redundant.
There was a problem hiding this comment.
Ok, considering changing it to
keepdims = attrs.get("keepdims", 0)
keepdims = 1 if keepdims in ["True", "1"] else 0
There was a problem hiding this comment.
I was thinking of this -
keepdims = 1 if attrs.get("keepdims", 0) in ["True", "1"] else 0
but this should be ok i guess.
| axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None | ||
|
|
||
| keepdims = 1 if ("keepdims" in attrs) and \ | ||
| attrs.get("keepdims") in ["True", "1"] else 0 |
There was a problem hiding this comment.
shouldn't attrs.get("keepdims") be cast to str before making this comparison? I am curious how it passed the unit test.
There was a problem hiding this comment.
the value has always been "True" or "1" or "0" when it reaches this function. same is the case with similar functions such as Reduce*, Arg*. I will figure out where this change in type occurs and get back to you.
| keepdims=keepdims, | ||
| name=name | ||
| ) | ||
|
|
There was a problem hiding this comment.
nit: could you remove new line here and in line 2150. Or maybe you could have a single return statement after the if else block.
There was a problem hiding this comment.
will change it to a common return statement
| "Pow", | ||
| [input_node_a, power2_name], | ||
| [name], | ||
| name=None |
There was a problem hiding this comment.
was a copy from Pow operator. No reason to have it as None. Will add a name. and change it in Pow as well
There was a problem hiding this comment.
Looks like name wasn't added to Pow because of the error "pow() got an unexpected keyword argument 'name'". Submitting the fix for this in this PR itself (separate commit).
|
LGTM |
|
@anirudhacharya @zhreshold ping for review |
| new_attrs = translation_utils._remove_attributes(new_attrs, ['broadcast']) | ||
| return 'broadcast_power', new_attrs, inputs | ||
| return 'pow', new_attrs, inputs | ||
| mxnet_op = symbol.pow(inputs[0], inputs[1]) |
There was a problem hiding this comment.
why this change? you probably made this change to accommodate op_set version 7 of ONNX. will we support op_set versions older than 7?
There was a problem hiding this comment.
Reverting back to the older commit. Will figure out supporting different op_set versions and work on the Power operator separately.
|
@zhreshold @Roshrini @anirudhacharya review comments addressed. is this PR good to go? |
| inputs = node["inputs"] | ||
|
|
||
| input_node_a_id = kwargs["index_lookup"][inputs[0][0]] | ||
|
|
There was a problem hiding this comment.
nit: in general feel that there are a lot of unnecessary blank lines in this method. also you can remove line 2096 have np.array([2]) directly. but this change is not a blocker
|
|
||
| initializer = kwargs["initializer"] | ||
| power2 = [2] | ||
| np_arr = np.array(power2) |
There was a problem hiding this comment.
what is the point of setting power2 as a temporary variable
There was a problem hiding this comment.
Agree. Re-submitting after editing this.
| axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None | ||
|
|
||
| keepdims = attrs.get("keepdims", 0) | ||
| keepdims = 1 if keepdims in ["True", "1"] else 0 |
There was a problem hiding this comment.
Agree. Re-submitting with this change.
Description
Enabling export of Square operator and Sum operator
v2: Added reduce_sum tests based on @anirudhacharya's review
v3: Addressed @Roshrini's and @zhreshold's review comments
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments