From 8a62bae0c1213f63b72525ea82181bcd4dea7b5c Mon Sep 17 00:00:00 2001
From: An Wang
Date: Fri, 10 Sep 2021 13:18:06 -0700
Subject: [PATCH 1/7] einsum

---
 include/tvm/relay/attrs/transform.h        |  9 +++
 python/tvm/relay/frontend/onnx.py          | 10 +++
 python/tvm/relay/op/_transform.py          |  6 ++
 python/tvm/relay/op/strategy/cuda.py       | 12 ++++
 python/tvm/relay/op/strategy/generic.py    | 21 ++++++
 python/tvm/relay/op/tensor.py              | 23 +++++++
 python/tvm/topi/generic/search.py          | 16 +++++
 src/relay/op/tensor/transform.cc           | 80 ++++++++++++++++++++++
 tests/python/frontend/onnx/test_forward.py |  5 --
 9 files changed, 177 insertions(+), 5 deletions(-)

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index a8317e1e51ad..b5815dc59286 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -489,6 +489,15 @@ struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> {
   }
 };  // struct UniqueAttrs
 
+/*! \brief Attributes used in einsum operator */
+struct EinsumAttrs : public tvm::AttrsNode<EinsumAttrs> {
+  String equation;
+
+  TVM_DECLARE_ATTRS(EinsumAttrs, "relay.attrs.EinsumAttrs") {
+    TVM_ATTR_FIELD(equation).describe("The einsum expression string");
+  }
+};
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 48a5b48b37d3..1b32ea4e9e88 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -3517,6 +3517,15 @@ def _impl_v11(cls, inputs, attr, params):
         return _expr.TupleWrapper(_expr.Tuple([unique_vals, indices, inverse_indices, counts]), 4)
 
 
+class Einsum(OnnxOpConverter):
+    """Operator converter for Einsum"""
+
+    @classmethod
+    def _impl_v12(cls, inputs, attr, params):
+        equation = attr["equation"].decode("utf-8")
+        return _op.einsum(inputs, equation)
+
+
 class RandomUniform(OnnxOpConverter):
     """Operator converter for random_uniform"""
 
@@ -3763,6 +3772,7 @@ def _get_convert_map(opset):
         "Range": Range.get_converter(opset),
         "CumSum": CumSum.get_converter(opset),
         "Unique": Unique.get_converter(opset),
+        "Einsum": Einsum.get_converter(opset),
         # defs/control_flow
         "Loop": Loop.get_converter(opset),
         "If": If.get_converter(opset),
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 7474f8e22bb5..efb0273617a6 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -182,6 +182,12 @@ def compute_unique(attrs, inputs, output_type):
 _reg.register_strategy("invert_permutation", strategy.invert_permutation_strategy)
 _reg.register_shape_func("invert_permutation", False, elemwise_shape_func)
 
+
+# einsum
+_reg.register_strategy("einsum", strategy.einsum_strategy)
+_reg.register_shape_func("einsum", False, elemwise_shape_func)
+
+
 #####################
 #  Shape functions  #
 #####################
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index ba47ae7bc4f1..a03f22700931 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -1210,3 +1210,15 @@ def invert_permutation_strategy_cuda(attrs, inputs, out_type, target):
         name="invert_permutation.cuda",
     )
     return strategy
+
+
+@einsum_strategy.register(["cuda", "gpu"])
+def einsum_strategy_cuda(attrs, inputs, out_type, target):
+    """einsum cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_einsum(topi.cuda.einsum),
+        wrap_topi_schedule(topi.cuda.schedule_einsum),
+        name="einsum.cuda",
+    )
+    return strategy
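[Reviewer aside, not part of the patch: a quick sketch that exercises the Einsum converter added above. It assumes the stock onnx helper API and the relay.frontend.from_onnx entry point; the node names and shapes are illustrative only.]

    import onnx
    from onnx import helper, TensorProto
    from tvm import relay

    # Build a one-node ONNX graph: out = einsum("ij,jk->ik", a, b), i.e. a plain matmul.
    node = helper.make_node("Einsum", ["a", "b"], ["out"], equation="ij,jk->ik")
    graph = helper.make_graph(
        [node],
        "einsum_example",
        inputs=[
            helper.make_tensor_value_info("a", TensorProto.FLOAT, [2, 3]),
            helper.make_tensor_value_info("b", TensorProto.FLOAT, [3, 4]),
        ],
        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, [2, 4])],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 12)])

    # During import, the "Einsum" entry registered above dispatches to
    # Einsum._impl_v12, which lowers the node to a relay einsum call.
    mod, params = relay.frontend.from_onnx(model)
    print(mod)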
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 9c756f201721..2822585caeaf 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1669,3 +1669,24 @@ def invert_permutation_strategy(attrs, inputs, out_type, target):
         name="invert_permutation.generic",
     )
     return strategy
+
+
+def wrap_compute_einsum(topi_compute):
+    """Wrap einsum topi compute"""
+
+    def _compute_einsum(attrs, inputs, _):
+        return [topi_compute(attrs.equation, *inputs)]
+
+    return _compute_einsum
+
+
+@override_native_generic_func("einsum_strategy")
+def einsum_strategy(attrs, inputs, out_type, target):
+    """einsum generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_einsum(topi.einsum),
+        wrap_topi_schedule(topi.generic.schedule_einsum),
+        name="einsum.generic",
+    )
+    return strategy
diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py
index a38a23064d6f..e47928919ce1 100644
--- a/python/tvm/relay/op/tensor.py
+++ b/python/tvm/relay/op/tensor.py
@@ -1104,6 +1104,29 @@ def concatenate(data, axis):
     return _make.concatenate(Tuple(data), axis)
 
 
+def einsum(data, equation):
+    """Evaluates the Einstein summation convention on data
+
+    Parameters
+    ----------
+    data : Union(List[relay.Expr], Tuple[relay.Expr])
+        A list of tensors.
+    equation : str
+        The einsum expression string.
+
+    Returns
+    -------
+    result : relay.Expr
+        The output tensor from the einsum op.
+    """
+    data = list(data)
+    if not data:
+        raise ValueError("relay.einsum requires data to be non-empty.")
+    if not isinstance(equation, str):
+        raise ValueError("einsum `equation` must be a str")
+    return _make.einsum(Tuple(data), equation)
+
+
 def stack(data, axis):
     """Join a sequence of arrays along a new axis.
 
diff --git a/python/tvm/topi/generic/search.py b/python/tvm/topi/generic/search.py
index f458ee7bc782..068551006e37 100644
--- a/python/tvm/topi/generic/search.py
+++ b/python/tvm/topi/generic/search.py
@@ -86,3 +86,19 @@ def schedule_unique(outs):
         The computation schedule for the op.
     """
     return _default_schedule(outs, False)
+
+
+def schedule_einsum(outs):
+    """Schedule for einsum operator.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of einsum.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 3781107eeee1..e4fc0083ccd1 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -33,6 +33,7 @@
 #include <tvm/tir/op.h>
 #include <tvm/topi/broadcast.h>
 #include <tvm/topi/detail/constant_utils.h>
+#include <tvm/topi/einsum.h>
 #include <tvm/topi/elemwise.h>
 #include <tvm/topi/nn.h>
 #include <tvm/topi/reduction.h>
@@ -2431,6 +2432,85 @@ RELAY_REGISTER_OP("broadcast_to_like")
     .set_attr<FTVMCompute>("FTVMCompute", BroadCastToLikeCompute)
     .set_attr<TOpPattern>("TOpPattern", kBroadcast);
 
+// relay.einsum
+TVM_REGISTER_NODE_TYPE(EinsumAttrs);
+
+bool EinsumRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+              const TypeReporter& reporter) {
+  // Check attrs
+  const EinsumAttrs* param = attrs.as<EinsumAttrs>();
+  if (param == nullptr) {
+    reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+                                     << "the call attributes are not defined");
+    return false;
+  }
+
+  // types: [data, result]
+  ICHECK_EQ(types.size(), 2) << "the arity of einsum is 2, not " << types.size();
+
+  // Check input type is a tuple.
+  const auto* tensor_tuple = types[0].as<TupleTypeNode>();
+  if (tensor_tuple == nullptr) {
+    reporter->GetDiagCtx().EmitFatal(
+        Diagnostic::Error(reporter->GetSpan())
+        << "einsum requires a tuple of tensors as the first argument, found "
+        << PrettyPrint(types[0]));
+    return false;
+  }
+
+  // Check the input tuple consists of tensors with consistent dtype.
+  const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
+  const DataType dtype = first->dtype;
+  std::vector<Array<PrimExpr>> input_shapes;
+  for (const Type& ele : tensor_tuple->fields) {
+    if (ele.as<IncompleteTypeNode>()) {
+      return false;
+    }
+
+    const auto& e = Downcast<TensorType>(ele);
+
+    const DataType& e_dtype = e->dtype;
+    if (e_dtype != dtype) {
+      throw Error("relay.einsum requires all tensors have the same dtype");
+    }
+    input_shapes.push_back(e->shape);
+  }
+
+  // Calculate output shape
+  Array<IndexExpr> oshape = topi::NumpyEinsumShape(param->equation, input_shapes);
+
+  auto rtype = TensorType(oshape, dtype);
+  reporter->Assign(types[1], rtype);
+  return true;
+}
+
+Array<te::Tensor> EinsumCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                               const Type& out_type) {
+  const EinsumAttrs* param = attrs.as<EinsumAttrs>();
+  ICHECK(param != nullptr);
+  return Array<te::Tensor>{topi::einsum(param->equation, inputs)};
+}
+
+Expr MakeEinsum(Expr data, String equation) {
+  auto attrs = make_object<EinsumAttrs>();
+  attrs->equation = std::move(equation);
+  static const Op& op = Op::Get("einsum");
+  return Call(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.einsum").set_body_typed(MakeEinsum);
+
+RELAY_REGISTER_OP("einsum")
+    .describe(R"doc(Evaluates the Einstein summation convention
+on the operands)doc" TVM_ADD_FILELINE)
+    .set_attrs_type<EinsumAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tuple of Tensors", "The input list of tensors.")
+    .set_support_level(11)
+    .add_type_rel("Einsum", EinsumRel)
+    .set_attr<FTVMCompute>("FTVMCompute", EinsumCompute)
+    .set_attr<TOpPattern>("TOpPattern", kInjective);
+
 // Adapter function to make int array.
 Array<Integer> GetIntArray(Array<IndexExpr> arr) {
   for (size_t i = 0; i < arr.size(); ++i) {
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 6fa4fcf47ccb..f9094e94a3ee 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4739,11 +4739,6 @@ def verify_eyelike(indata):
     "test_dropout_default_mask",
     "test_dropout_default_mask_ratio",
     "test_dropout_default_ratio",
-    "test_einsum_batch_diagonal",
-    "test_einsum_batch_matmul",
-    "test_einsum_inner_prod",
-    "test_einsum_sum",
-    "test_einsum_transpose",
     "test_greater_equal",
     "test_greater_equal_bcast",
     "test_if_seq",

From 1dae1c36aaafa70aec647a4e40982bf3ee7317af Mon Sep 17 00:00:00 2001
From: An Wang
Date: Fri, 10 Sep 2021 14:03:13 -0700
Subject: [PATCH 2/7] address review

---
 include/tvm/relay/attrs/transform.h | 4 ++--
 python/tvm/relay/op/_transform.py   | 1 -
 src/relay/op/tensor/transform.cc    | 4 ++--
 3 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index b5815dc59286..e3f9bad17ef5 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -475,7 +475,7 @@ struct ScanopAttrs : public tvm::AttrsNode<ScanopAttrs> {
         .describe("The first element is not included")
         .set_default(Bool(false));
   }
-};
+};  // struct ScanopAttrs
 
 /*! \brief Attributes used in unique operator */
 struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> {
@@ -496,7 +496,7 @@ struct EinsumAttrs : public tvm::AttrsNode<EinsumAttrs> {
   TVM_DECLARE_ATTRS(EinsumAttrs, "relay.attrs.EinsumAttrs") {
     TVM_ATTR_FIELD(equation).describe("The einsum expression string");
   }
-};
+};  // struct EinsumAttrs
 
 }  // namespace relay
 }  // namespace tvm
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index efb0273617a6..117f32e64e9f 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -185,7 +185,6 @@ def compute_unique(attrs, inputs, output_type):
 
 # einsum
 _reg.register_strategy("einsum", strategy.einsum_strategy)
-_reg.register_shape_func("einsum", False, elemwise_shape_func)
 
 
 #####################
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index e4fc0083ccd1..70ad0b3e26e3 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -2436,7 +2436,7 @@ RELAY_REGISTER_OP("broadcast_to_like")
 TVM_REGISTER_NODE_TYPE(EinsumAttrs);
 
 bool EinsumRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
-              const TypeReporter& reporter) {
+               const TypeReporter& reporter) {
   // Check attrs
   const EinsumAttrs* param = attrs.as<EinsumAttrs>();
   if (param == nullptr) {
@@ -2485,7 +2485,7 @@ bool EinsumRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 }
 
 Array<te::Tensor> EinsumCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
-                               const Type& out_type) {
+                                const Type& out_type) {
   const EinsumAttrs* param = attrs.as<EinsumAttrs>();
   ICHECK(param != nullptr);
   return Array<te::Tensor>{topi::einsum(param->equation, inputs)};

From 1d03f105315de8db4166a37421a49a92a5f98f48 Mon Sep 17 00:00:00 2001
From: An Wang
Date: Mon, 13 Sep 2021 11:23:29 -0700
Subject: [PATCH 3/7] move files around

---
 python/tvm/relay/op/__init__.py     |   1 +
 python/tvm/relay/op/_math.py        |  22 ++++++
 python/tvm/relay/op/_transform.py   |   4 -
 python/tvm/topi/generic/__init__.py |   1 +
 python/tvm/topi/generic/math.py     |  34 ++++++++
 python/tvm/topi/generic/search.py   |  16 ----
 src/relay/op/tensor/math.cc         | 115 ++++++++++++++++++++++++++++
 src/relay/op/tensor/transform.cc    |  80 --------------------
 8 files changed, 173 insertions(+), 100 deletions(-)
 create mode 100644 python/tvm/relay/op/_math.py
 create mode 100644 python/tvm/topi/generic/math.py
 create mode 100644 src/relay/op/tensor/math.cc

diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py
index 2e509a111c4a..825bd1f627ca 100644
--- a/python/tvm/relay/op/__init__.py
+++ b/python/tvm/relay/op/__init__.py
@@ -54,6 +54,7 @@
 from . import _transform
 from . import _reduce
 from . import _algorithm
+from . import _math
 
 
 def _register_op_make():
diff --git a/python/tvm/relay/op/_math.py b/python/tvm/relay/op/_math.py
new file mode 100644
index 000000000000..ff74fafcef75
--- /dev/null
+++ b/python/tvm/relay/op/_math.py
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Backend compiler related feature registration"""
+from . import op as _reg
+from . import strategy
+
+# einsum
+_reg.register_strategy("einsum", strategy.einsum_strategy)
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 117f32e64e9f..0284d2483ce5 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -183,10 +183,6 @@ def compute_unique(attrs, inputs, output_type):
 _reg.register_shape_func("invert_permutation", False, elemwise_shape_func)
 
 
-# einsum
-_reg.register_strategy("einsum", strategy.einsum_strategy)
-
-
 #####################
 #  Shape functions  #
 #####################
diff --git a/python/tvm/topi/generic/__init__.py b/python/tvm/topi/generic/__init__.py
index cc64abab8ed8..021f9a1bbe1d 100644
--- a/python/tvm/topi/generic/__init__.py
+++ b/python/tvm/topi/generic/__init__.py
@@ -39,3 +39,4 @@
 from .sort import *
 from .search import *
 from .image import *
+from .math import *
diff --git a/python/tvm/topi/generic/math.py b/python/tvm/topi/generic/math.py
new file mode 100644
index 000000000000..3af6cd16a374
--- /dev/null
+++ b/python/tvm/topi/generic/math.py
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Generic math operators"""
+from .default import default_schedule as _default_schedule
+
+
+def schedule_einsum(outs):
+    """Schedule for einsum operator.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of einsum.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
diff --git a/python/tvm/topi/generic/search.py b/python/tvm/topi/generic/search.py
index 068551006e37..f458ee7bc782 100644
--- a/python/tvm/topi/generic/search.py
+++ b/python/tvm/topi/generic/search.py
@@ -86,19 +86,3 @@ def schedule_unique(outs):
         The computation schedule for the op.
     """
     return _default_schedule(outs, False)
-
-
-def schedule_einsum(outs):
-    """Schedule for einsum operator.
-
-    Parameters
-    ----------
-    outs: Array of Tensor
-        The computation graph description of einsum.
-
-    Returns
-    -------
-    s: Schedule
-        The computation schedule for the op.
-    """
-    return _default_schedule(outs, False)
diff --git a/src/relay/op/tensor/math.cc b/src/relay/op/tensor/math.cc
new file mode 100644
index 000000000000..246fba62cc66
--- /dev/null
+++ b/src/relay/op/tensor/math.cc
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file math.cc
+ * \brief Math operators.
+ */
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr.h>
+#include <tvm/topi/einsum.h>
+
+#include "../make_op.h"
+#include "../op_common.h"
+#include "../type_relations.h"
+
+namespace tvm {
+namespace relay {
+
+// relay.einsum
+TVM_REGISTER_NODE_TYPE(EinsumAttrs);
+
+bool EinsumRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+               const TypeReporter& reporter) {
+  // Check attrs
+  const EinsumAttrs* param = attrs.as<EinsumAttrs>();
+  if (param == nullptr) {
+    reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+                                     << "the call attributes are not defined");
+    return false;
+  }
+
+  // types: [data, result]
+  ICHECK_EQ(types.size(), 2) << "the arity of einsum is 2, not " << types.size();
+
+  // Check input type is a tuple.
+  const auto* tensor_tuple = types[0].as<TupleTypeNode>();
+  if (tensor_tuple == nullptr) {
+    reporter->GetDiagCtx().EmitFatal(
+        Diagnostic::Error(reporter->GetSpan())
+        << "einsum requires a tuple of tensors as the first argument, found "
+        << PrettyPrint(types[0]));
+    return false;
+  }
+
+  // Check the input tuple consists of tensors with consistent dtype.
+  const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
+  const DataType dtype = first->dtype;
+  std::vector<Array<PrimExpr>> input_shapes;
+  for (const Type& ele : tensor_tuple->fields) {
+    if (ele.as<IncompleteTypeNode>()) {
+      return false;
+    }
+
+    const auto& e = Downcast<TensorType>(ele);
+
+    const DataType& e_dtype = e->dtype;
+    if (e_dtype != dtype) {
+      throw Error("relay.einsum requires all tensors have the same dtype");
+    }
+    input_shapes.push_back(e->shape);
+  }
+
+  // Calculate output shape
+  Array<IndexExpr> oshape = topi::NumpyEinsumShape(param->equation, input_shapes);
+
+  auto rtype = TensorType(oshape, dtype);
+  reporter->Assign(types[1], rtype);
+  return true;
+}
+
+Array<te::Tensor> EinsumCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                const Type& out_type) {
+  const EinsumAttrs* param = attrs.as<EinsumAttrs>();
+  ICHECK(param != nullptr);
+  return Array<te::Tensor>{topi::einsum(param->equation, inputs)};
+}
+
+Expr MakeEinsum(Expr data, String equation) {
+  auto attrs = make_object<EinsumAttrs>();
+  attrs->equation = std::move(equation);
+  static const Op& op = Op::Get("einsum");
+  return Call(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.einsum").set_body_typed(MakeEinsum);
+
+RELAY_REGISTER_OP("einsum")
+    .describe(R"doc(Evaluates the Einstein summation convention
+on the operands)doc" TVM_ADD_FILELINE)
+    .set_attrs_type<EinsumAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tuple of Tensors", "The input list of tensors.")
+    .set_support_level(11)
+    .add_type_rel("Einsum", EinsumRel)
+    .set_attr<FTVMCompute>("FTVMCompute", EinsumCompute)
+    .set_attr<TOpPattern>("TOpPattern", kInjective);
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 70ad0b3e26e3..3781107eeee1 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -33,7 +33,6 @@
 #include <tvm/tir/op.h>
 #include <tvm/topi/broadcast.h>
 #include <tvm/topi/detail/constant_utils.h>
-#include <tvm/topi/einsum.h>
 #include <tvm/topi/elemwise.h>
 #include <tvm/topi/nn.h>
 #include <tvm/topi/reduction.h>
@@ -2431,85 +2431,6 @@ RELAY_REGISTER_OP("broadcast_to_like")
     .set_attr<FTVMCompute>("FTVMCompute", BroadCastToLikeCompute)
     .set_attr<TOpPattern>("TOpPattern", kBroadcast);
 
-// relay.einsum
-TVM_REGISTER_NODE_TYPE(EinsumAttrs);
-
-bool EinsumRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
-               const TypeReporter& reporter) {
-  // Check attrs
-  const EinsumAttrs* param = attrs.as<EinsumAttrs>();
-  if (param == nullptr) {
-    reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
-                                     << "the call attributes are not defined");
-    return false;
-  }
-
-  // types: [data, result]
-  ICHECK_EQ(types.size(), 2) << "the arity of einsum is 2, not " << types.size();
-
-  // Check input type is a tuple.
-  const auto* tensor_tuple = types[0].as<TupleTypeNode>();
-  if (tensor_tuple == nullptr) {
-    reporter->GetDiagCtx().EmitFatal(
-        Diagnostic::Error(reporter->GetSpan())
-        << "einsum requires a tuple of tensors as the first argument, found "
-        << PrettyPrint(types[0]));
-    return false;
-  }
-
-  // Check the input tuple consists of tensors with consistent dtype.
-  const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
-  const DataType dtype = first->dtype;
-  std::vector<Array<PrimExpr>> input_shapes;
-  for (const Type& ele : tensor_tuple->fields) {
-    if (ele.as<IncompleteTypeNode>()) {
-      return false;
-    }
-
-    const auto& e = Downcast<TensorType>(ele);
-
-    const DataType& e_dtype = e->dtype;
-    if (e_dtype != dtype) {
-      throw Error("relay.einsum requires all tensors have the same dtype");
-    }
-    input_shapes.push_back(e->shape);
-  }
-
-  // Calculate output shape
-  Array<IndexExpr> oshape = topi::NumpyEinsumShape(param->equation, input_shapes);
-
-  auto rtype = TensorType(oshape, dtype);
-  reporter->Assign(types[1], rtype);
-  return true;
-}
-
-Array<te::Tensor> EinsumCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
-                                const Type& out_type) {
-  const EinsumAttrs* param = attrs.as<EinsumAttrs>();
-  ICHECK(param != nullptr);
-  return Array<te::Tensor>{topi::einsum(param->equation, inputs)};
-}
-
-Expr MakeEinsum(Expr data, String equation) {
-  auto attrs = make_object<EinsumAttrs>();
-  attrs->equation = std::move(equation);
-  static const Op& op = Op::Get("einsum");
-  return Call(op, {data}, Attrs(attrs), {});
-}
-
-TVM_REGISTER_GLOBAL("relay.op._make.einsum").set_body_typed(MakeEinsum);
-
-RELAY_REGISTER_OP("einsum")
-    .describe(R"doc(Evaluates the Einstein summation convention
-on the operands)doc" TVM_ADD_FILELINE)
-    .set_attrs_type<EinsumAttrs>()
-    .set_num_inputs(1)
-    .add_argument("data", "Tuple of Tensors", "The input list of tensors.")
-    .set_support_level(11)
-    .add_type_rel("Einsum", EinsumRel)
-    .set_attr<FTVMCompute>("FTVMCompute", EinsumCompute)
-    .set_attr<TOpPattern>("TOpPattern", kInjective);
-
 // Adapter function to make int array.
 Array<Integer> GetIntArray(Array<IndexExpr> arr) {
   for (size_t i = 0; i < arr.size(); ++i) {

From 653a9aab5fe51c0376e986f7018cb734d90f46f9 Mon Sep 17 00:00:00 2001
From: An Wang
Date: Tue, 14 Sep 2021 14:04:03 -0700
Subject: [PATCH 4/7] use generic topi op

---
 python/tvm/relay/op/strategy/cuda.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index a03f22700931..88c930ead282 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -1217,8 +1217,8 @@ def einsum_strategy_cuda(attrs, inputs, out_type, target):
     """einsum cuda strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_einsum(topi.cuda.einsum),
-        wrap_topi_schedule(topi.cuda.schedule_einsum),
+        wrap_compute_einsum(topi.einsum),
+        wrap_topi_schedule(topi.generic.schedule_extern),
         name="einsum.cuda",
     )
     return strategy

From 64a4f4af8c4f0aae45e5f77a92ced125b77d5df6 Mon Sep 17 00:00:00 2001
From: An Wang
Date: Tue, 14 Sep 2021 14:05:26 -0700
Subject: [PATCH 5/7] TODO comment

---
 python/tvm/relay/op/strategy/cuda.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 88c930ead282..3a3e3037d274 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -1216,6 +1216,7 @@ def invert_permutation_strategy_cuda(attrs, inputs, out_type, target):
 def einsum_strategy_cuda(attrs, inputs, out_type, target):
     """einsum cuda strategy"""
     strategy = _op.OpStrategy()
+    # TODO: Add cuda-specific op implementation for einsum
     strategy.add_implementation(
         wrap_compute_einsum(topi.einsum),
         wrap_topi_schedule(topi.generic.schedule_extern),

From 3481e567968c83a2190033d60ca4531078a79f35 Mon Sep 17 00:00:00 2001
From: An Wang
Date: Tue, 14 Sep 2021 14:29:18 -0700
Subject: [PATCH 6/7] jostle ci

From 69974b6dcd78f3dd5b62ff64529320575627512a Mon Sep 17 00:00:00 2001
From: An Wang
Date: Tue, 14 Sep 2021 15:37:57 -0700
Subject: [PATCH 7/7] jostle ci
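[Reviewer aside, not part of the patch series: a minimal end-to-end sketch of the op this series adds, assuming a TVM build with these patches applied and a CPU (llvm) target; the shapes and random inputs are illustrative only.]

    import numpy as np
    import tvm
    from tvm import relay

    # "ij,jk->ik" is an ordinary matrix multiply written as an einsum equation.
    x = relay.var("x", shape=(2, 3), dtype="float32")
    y = relay.var("y", shape=(3, 4), dtype="float32")
    out = relay.einsum([x, y], "ij,jk->ik")
    mod = tvm.IRModule.from_expr(relay.Function([x, y], out))

    a = np.random.rand(2, 3).astype("float32")
    b = np.random.rand(3, 4).astype("float32")
    res = relay.create_executor("graph", mod=mod, target="llvm").evaluate()(a, b)

    # The Relay result should match NumPy's einsum for the same equation.
    np.testing.assert_allclose(res.numpy(), np.einsum("ij,jk->ik", a, b), rtol=1e-5)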