Merged
69 commits
e8994e4
initial dyn unsqueeze example
Sep 16, 2021
b044395
simplify, properly unpack scalar
Sep 16, 2021
3c0b669
basic tests
Sep 16, 2021
2c3facb
squish bugs -- assign proper types
AndrewZhaoLuo Sep 16, 2021
045ab70
working topi
AndrewZhaoLuo Sep 16, 2021
20f55dd
fix things
AndrewZhaoLuo Sep 16, 2021
e47d78f
temp work
AndrewZhaoLuo Sep 16, 2021
cd15f1a
fix casting to int64
AndrewZhaoLuo Sep 17, 2021
e70f24d
shape encoding method for axis
AndrewZhaoLuo Sep 17, 2021
ba609a8
working shape encoding metric
AndrewZhaoLuo Sep 17, 2021
89fd696
add comment
AndrewZhaoLuo Sep 17, 2021
e21dcc1
move to non-rank encoded axis
AndrewZhaoLuo Sep 17, 2021
7f2a602
failing regime
AndrewZhaoLuo Sep 17, 2021
5a3c47b
fix
AndrewZhaoLuo Sep 17, 2021
c83a1a6
it works!
AndrewZhaoLuo Sep 17, 2021
d5749f3
add test
AndrewZhaoLuo Sep 17, 2021
b46f70b
add comment on shape func
AndrewZhaoLuo Sep 17, 2021
f681dd3
remove unused topi
AndrewZhaoLuo Sep 17, 2021
c73020d
undo some file changes
AndrewZhaoLuo Sep 17, 2021
cf3ace8
more cleanup
AndrewZhaoLuo Sep 17, 2021
2d0d48d
newline
AndrewZhaoLuo Sep 17, 2021
7669645
clean up
AndrewZhaoLuo Sep 17, 2021
14629ec
clean up
AndrewZhaoLuo Sep 17, 2021
2aa5607
enable multiple axis tests
AndrewZhaoLuo Sep 17, 2021
b2cf8ce
move tests to dynamic op
AndrewZhaoLuo Sep 17, 2021
de7320d
Update docs
AndrewZhaoLuo Sep 17, 2021
75d29c5
add converter
AndrewZhaoLuo Sep 17, 2021
385246a
initial dyn unsqueeze example
Sep 16, 2021
533d198
simplify, properly unpack scalar
Sep 16, 2021
5b75ee8
basic tests
Sep 16, 2021
63e9ec9
squish bugs -- assign proper types
AndrewZhaoLuo Sep 16, 2021
72322d0
working topi
AndrewZhaoLuo Sep 16, 2021
9b710ca
fix things
AndrewZhaoLuo Sep 16, 2021
cbf49a5
temp work
AndrewZhaoLuo Sep 16, 2021
85c3d82
fix casting to int64
AndrewZhaoLuo Sep 17, 2021
48e9129
shape encoding method for axis
AndrewZhaoLuo Sep 17, 2021
0013dfd
working shape encoding metric
AndrewZhaoLuo Sep 17, 2021
a187e57
add comment
AndrewZhaoLuo Sep 17, 2021
bac9e50
move to non-rank encoded axis
AndrewZhaoLuo Sep 17, 2021
7decc4d
failing regime
AndrewZhaoLuo Sep 17, 2021
d26ff7a
fix
AndrewZhaoLuo Sep 17, 2021
fe3ce5c
it works!
AndrewZhaoLuo Sep 17, 2021
17c9f65
add test
AndrewZhaoLuo Sep 17, 2021
cc4f30c
add comment on shape func
AndrewZhaoLuo Sep 17, 2021
55b7875
remove unused topi
AndrewZhaoLuo Sep 17, 2021
8d72e8b
undo some file changes
AndrewZhaoLuo Sep 17, 2021
118934f
more cleanup
AndrewZhaoLuo Sep 17, 2021
0d11d7c
newline
AndrewZhaoLuo Sep 17, 2021
4451059
clean up
AndrewZhaoLuo Sep 17, 2021
4310732
clean up
AndrewZhaoLuo Sep 17, 2021
227d912
enable multiple axis tests
AndrewZhaoLuo Sep 17, 2021
720fb8e
move tests to dynamic op
AndrewZhaoLuo Sep 17, 2021
25f8c23
Update docs
AndrewZhaoLuo Sep 17, 2021
0b9ec33
add converter
AndrewZhaoLuo Sep 17, 2021
feb635c
Merge branch 'aluo/onnx/unsqueeze-alt' of github.com:AndrewZhaoLuo/tv…
Sep 17, 2021
12709e7
working tests
Sep 19, 2021
ff39840
add test, remove unneeded file
AndrewZhaoLuo Sep 19, 2021
f7b045f
fix things
Sep 19, 2021
39dba34
more lint
Sep 19, 2021
6caadf3
more lint
Sep 19, 2021
a9f5117
pick things
AndrewZhaoLuo Sep 19, 2021
c0ca64b
Merge branch 'aluo/onnx/unsqueeze-alt' of github.com:AndrewZhaoLuo/tv…
AndrewZhaoLuo Sep 19, 2021
39f729f
disable opencl tests
AndrewZhaoLuo Sep 19, 2021
4f6725e
unsqueeze tests
Sep 19, 2021
2bf1fb0
clean up
AndrewZhaoLuo Sep 20, 2021
3f5e26d
Merge branch 'main' into aluo/onnx/unsqueeze-alt
AndrewZhaoLuo Sep 20, 2021
0b58530
dyn stuff
AndrewZhaoLuo Sep 21, 2021
7e2ebac
add num_newaxis
AndrewZhaoLuo Sep 21, 2021
dc17ada
Merge branch 'main' into aluo/onnx/unsqueeze-alt
AndrewZhaoLuo Sep 21, 2021
12 changes: 12 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -60,6 +60,18 @@ struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
}
}; // struct ExpandDimsAttrs

/*! \brief Attributes used in dynamic expand_dims operators */
struct DynExpandDimsAttrs : public tvm::AttrsNode<DynExpandDimsAttrs> {
int num_newaxis;

TVM_DECLARE_ATTRS(DynExpandDimsAttrs, "relay.attrs.DynExpandDimsAttrs") {
TVM_ATTR_FIELD(num_newaxis)
.describe("Number of axes to be inserted. Should be >= 0.")
.set_lower_bound(0)
.set_default(1);
}
}; // struct DynExpandDimsAttrs

/*! \brief Attributes used in concatenate operators */
struct ConcatenateAttrs : public tvm::AttrsNode<ConcatenateAttrs> {
int axis;
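For a quick sanity check of the new attribute, here is a minimal sketch of how `num_newaxis` surfaces on a constructed call node (an illustrative snippet, not part of the PR; it assumes the Python bindings registered later in this diff):

from tvm import relay

x = relay.var("x", shape=(2, 2), dtype="float32")
axis = relay.var("axis", shape=(), dtype="int64")
# A non-constant axis dispatches to the new dyn.expand_dims op
call = relay.expand_dims(x, axis, num_newaxis=3)
print(call.op.name)                 # dyn.expand_dims
print(int(call.attrs.num_newaxis))  # 3
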
20 changes: 20 additions & 0 deletions python/tvm/relay/frontend/onnx.py
@@ -1462,6 +1462,26 @@ def _impl_v1(cls, inputs, attr, params):
inputs[0] = _op.expand_dims(inputs[0], axis=axis, num_newaxis=1)
return inputs[0]

@classmethod
def _impl_v12(cls, inputs, attr, params):
rank_input = len(infer_type(inputs[0]).checked_type.shape)
num_new_axis = int(infer_type(inputs[1]).checked_type.shape[0])
axes = relay.split(inputs[1], num_new_axis).astuple()

result = inputs[0]

# TODO (AndrewZhaoLuo): investigate performance issues with consecutive
# dynamic expand_dims on non-llvm targets.
for i in range(num_new_axis):
axis = relay.TupleGetItem(axes, i)
# Unpack scalar
axis = relay.reshape(axis, [])
axis = relay.If(
axis >= relay.const(0, "int64"), axis, axis + relay.const(rank_input, "int64")
)
result = _op.expand_dims(result, axis)
return result

Comment on lines +1474 to +1484

Contributor: Again, I think this should be doable as one call?

Contributor (Author): This might be doable if we change the interface to accept a list of sorted axes, but it'll be quite a bit more complicated.


class Split(OnnxOpConverter):
"""Operator converter for Split."""
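For intuition, here is a plain-NumPy sketch of what the per-axis loop in `_impl_v12` computes; `unsqueeze_dynamic` is illustrative only and not part of the PR:

import numpy as np

def unsqueeze_dynamic(data, axes):
    # Mirror the converter: normalize negative axes against the rank of
    # the original input, then insert the new axes one at a time.
    rank_input = data.ndim
    result = data
    for axis in axes:
        if axis < 0:
            axis += rank_input  # matches the relay.If branch above
        result = np.expand_dims(result, axis)
    return result

x = np.zeros((2, 2))
print(unsqueeze_dynamic(x, [0, 2]).shape)  # (1, 2, 1, 2)
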
38 changes: 38 additions & 0 deletions python/tvm/relay/op/dyn/_transform.py
@@ -20,10 +20,12 @@

from tvm.runtime import convert
from tvm.te.hybrid import script

from .. import op as _reg

_reg.register_broadcast_schedule("dyn.broadcast_to")
_reg.register_injective_schedule("dyn.reshape")
_reg.register_injective_schedule("dyn.expand_dims")
_reg.register_broadcast_schedule("dyn.tile")
_reg.register_injective_schedule("dyn.one_hot")
_reg.register_injective_schedule("dyn.full")
@@ -89,6 +91,42 @@ def dynamic_reshape_shape_func(attrs, inputs, out_ndims):
return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]


@script
def _expand_dims_shape_func_input_data(data, axis, ndims, num_newaxis):
out = output_tensor((ndims,), "int64")

for i in const_range(ndims):
if i < axis:
# We multiply the index by the bounds check (i < len(data.shape)) so
# that constant folding cannot produce an out-of-bounds access
out[i] = int64(data.shape[i * (i < len(data.shape))])
elif i - num_newaxis < axis:
out[i] = int64(1)
else:
out[i] = int64(
# `axis` is not a compile-time constant, so it can't appear in the
# index directly; instead we emulate negative indexing manually
data.shape[
(i - num_newaxis) * (i - num_newaxis >= 0)
+ (i - num_newaxis + len(data.shape)) * (i - num_newaxis < 0)
]
)

return out


@_reg.register_shape_func("dyn.expand_dims", [True, True])
def dynamic_expand_dims_shape_func(attrs, inputs, out_ndims):
return [
_expand_dims_shape_func_input_data(
inputs[0],
inputs[1],
out_ndims[0],
convert(attrs.num_newaxis),
)
]


@script
def _tile_shape_func(data, reps, ndim, tndim, rndim):
out = output_tensor((tndim,), "int64")
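As a readability aid, here is a plain-Python transcription of `_expand_dims_shape_func_input_data` with the hybrid-script constant-folding workarounds stripped out (a sketch for illustration, not code from the PR):

def expand_dims_out_shape(data_shape, axis, ndims, num_newaxis):
    # ndims == len(data_shape) + num_newaxis
    out = []
    for i in range(ndims):
        if i < axis:
            out.append(data_shape[i])  # dims before the insertion point
        elif i - num_newaxis < axis:
            out.append(1)  # the newly inserted axes
        else:
            out.append(data_shape[i - num_newaxis])  # dims after
    return out

# Input shape (2, 3), one new axis at position 1 -> [2, 1, 3]
print(expand_dims_out_shape([2, 3], axis=1, ndims=3, num_newaxis=1))
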
10 changes: 8 additions & 2 deletions python/tvm/relay/op/transform.py
@@ -96,7 +96,7 @@ def expand_dims(data, axis, num_newaxis=1):
data : relay.Expr
The input data to the operator.

axis : int
axis : Union[int, Expr]
The axis at which the input array is expanded.
Should lie in range `[-data.ndim - 1, data.ndim]`.
If `axis < 0`, it is the first axis inserted;
@@ -110,7 +110,13 @@
result : relay.Expr
The reshaped result.
"""
return _make.expand_dims(data, axis, num_newaxis)
if isinstance(axis, int):
return _make.expand_dims(data, axis, num_newaxis)
if isinstance(axis, Expr):
# TODO (AndrewZhaoLuo): investigate performance issues with consecutive
# dynamic expand_dims on non-llvm targets.
return _dyn_make.expand_dims(data, axis, num_newaxis)
raise ValueError(f"Unknown type for axis: {type(axis)}")


def transpose(data, axes=None):
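A minimal end-to-end usage sketch of the new dynamic path (the VM executor is needed because the output shape is data-dependent; names here are illustrative):

import numpy as np
from tvm import relay

x = relay.var("x", shape=(2, 2), dtype="float32")
axis = relay.var("axis", shape=(), dtype="int64")
func = relay.Function([x, axis], relay.expand_dims(x, axis, num_newaxis=2))

# axis=1 with num_newaxis=2 turns (2, 2) into (2, 1, 1, 2)
out = relay.create_executor("vm").evaluate(func)(
    np.zeros((2, 2), "float32"), np.array(1, "int64")
)
print(out.numpy().shape)  # (2, 1, 1, 2)
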
74 changes: 74 additions & 0 deletions src/relay/op/dyn/tensor/transform.cc
@@ -618,6 +618,80 @@ RELAY_REGISTER_OP("dyn.sparse_to_dense")
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute", SparseToDenseCompute);

/* relay.dyn.expand_dims */
TVM_REGISTER_NODE_TYPE(DynExpandDimsAttrs);

bool ExpandDimsRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
ICHECK_EQ(num_inputs, 2);
const auto* data_type = types[0].as<TensorTypeNode>();
if (data_type == nullptr) {
ICHECK(types[0].as<IncompleteTypeNode>())
<< "expand_dims: expect input type to be TensorType but get " << types[0];
return false;
}

const auto* param = attrs.as<DynExpandDimsAttrs>();

// We don't know the output shape until we see the value of the axis input
int ndim = data_type->shape.size();
Array<IndexExpr> oshape(ndim + param->num_newaxis, Any());

const auto* axis_type = types[1].as<TensorTypeNode>();
ICHECK(axis_type->shape.size() == 0) << "Axis should be a scalar, got shape " << axis_type->shape;

// Set output shape
reporter->Assign(types[2], TensorType(oshape, data_type->dtype));
return true;
}

Array<te::Tensor> ExpandDimsCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
// inputs = [Input tensor, axis to expand]
ICHECK_EQ(inputs.size(), 2);

const auto* param = attrs.as<DynExpandDimsAttrs>();

Array<IndexExpr> ishape = inputs[0]->shape;
const TensorTypeNode* out_ttype = out_type.as<TensorTypeNode>();
int ndim_out = out_ttype->shape.size();
int ndim_in = ishape.size();
ICHECK_EQ(ndim_in + param->num_newaxis, ndim_out);

Array<IndexExpr> newshape;
for (auto val : out_ttype->shape) {
// These vars will be populated by the VM executor with the results
// of the shape_func for the op.
newshape.push_back(val.as<tir::AnyNode>()->ToVar());
}

return {topi::reshape(inputs[0], newshape)};
}

Expr MakeExpandDims(Expr data, Expr axis_tensor, int num_newaxis) {
auto attrs = make_object<DynExpandDimsAttrs>();
attrs->num_newaxis = num_newaxis;
static const Op& op = Op::Get("dyn.expand_dims");
return Call(op, {data, axis_tensor}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.dyn._make.expand_dims").set_body_typed(MakeExpandDims);

RELAY_REGISTER_OP("dyn.expand_dims")
.describe(R"code(Insert one new axis at the position given by `axis`

- **data**: The input data to the operator.
- **axis**: The axis to insert a new dimension

)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("axis", "Tensor", "The axis to insert at a dimension.")
.set_support_level(3)
.add_type_rel("DynamicExpandDims", ExpandDimsRel)
.set_attr<FTVMCompute>("FTVMCompute", ExpandDimsCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

} // namespace dyn
} // namespace relay
} // namespace tvm
11 changes: 4 additions & 7 deletions tests/python/frontend/onnx/test_forward.py
@@ -5015,16 +5015,13 @@ def verify_eyelike(indata):
"test_training_dropout_mask",
"test_training_dropout_zero_ratio",
"test_training_dropout_zero_ratio_mask",
"test_unique_sorted_with_axis",
"test_unique_sorted_with_axis_3d",
"test_unique_sorted_with_negative_axis",
"test_unsqueeze_axis_0",
"test_unsqueeze_axis_1",
"test_unsqueeze_axis_2",
"test_unsqueeze_negative_axes",
# These unsqueeze tests work, but take 2+ hrs to run
"test_unsqueeze_three_axes",
"test_unsqueeze_two_axes",
"test_unsqueeze_unsorted_axes",
Comment on lines 5019 to 5021

Contributor: Can you add these to device-specific skips below?

Contributor (Author): Hmm, actually all targets seem really slow. The LLVM target takes > 3 minutes for test_unsqueeze_two_axes.

"test_unique_sorted_with_axis",
"test_unique_sorted_with_axis_3d",
"test_unique_sorted_with_negative_axis",
"test_upsample_nearest",
]

34 changes: 31 additions & 3 deletions tests/python/relay/dyn/test_dynamic_op_level3.py
@@ -19,11 +19,10 @@
import numpy as np
import pytest
import tvm
from tvm import te
from tvm import relay
import tvm.testing
from tvm import relay, te
from tvm.relay import create_executor, transform
from tvm.relay.testing import check_grad, run_infer_type
import tvm.testing


def verify_func(func, data, ref_res, target_device=tvm.testing.enabled_targets()):
@@ -93,6 +92,35 @@ def verify_reshape(shape, newshape, oshape):
verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))


@tvm.testing.uses_gpu
def test_dyn_expand_dims():
def verify_expand_dims(
dshape, dtype, oshape, axis, num_newaxis, target_device=tvm.testing.enabled_targets()
):
# Use 1 to avoid issues with invalid buffer sizes
x = relay.Var("x", relay.TensorType(dshape, dtype))
y = relay.var("axis", shape=[], dtype="int64")
z = relay.expand_dims(x, axis=y, num_newaxis=num_newaxis)
func = relay.Function([x, y], z)

data_np = np.random.uniform(size=dshape).astype(dtype)
axis_np = np.array(axis).astype("int64")
ref_res = data_np.reshape(oshape)
verify_func(func, [data_np, axis_np], ref_res, target_device=target_device)

for dtype in ["float16", "float32"]:
verify_expand_dims((2, 2), dtype, (2, 2, 1), 2, 1)
verify_expand_dims((2, 2), dtype, (2, 1, 2), 1, 1)
verify_expand_dims((2, 2), dtype, (1, 2, 2), 0, 1)

# TODO (AndrewZhaoLuo): investigate why runtimes in non-llvm are extremely slow
# for multiple new axis
llvm_target_only = [x for x in tvm.testing.enabled_targets() if "llvm" in x]
verify_expand_dims((2, 2), dtype, (2, 2, 1, 1), 2, 2, target_device=llvm_target_only)
verify_expand_dims((2, 2), dtype, (2, 1, 1, 1, 2), 1, 3, target_device=llvm_target_only)
verify_expand_dims((2, 2), dtype, (1, 1, 1, 1, 2, 2), 0, 4, target_device=llvm_target_only)


@tvm.testing.uses_gpu
def test_dyn_tile():
def verify_tile(dshape, reps):