diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index a8317e1e51ad..e3f9bad17ef5 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -475,7 +475,7 @@ struct ScanopAttrs : public tvm::AttrsNode<ScanopAttrs> {
         .describe("The first element is not included")
         .set_default(Bool(false));
   }
-};
+};  // struct ScanopAttrs
 
 /*! \brief Attributes used in unique operator */
 struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> {
@@ -489,6 +489,15 @@ struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> {
   }
 };  // struct UniqueAttrs
 
+/*! \brief Attributes used in einsum operator */
+struct EinsumAttrs : public tvm::AttrsNode<EinsumAttrs> {
+  String equation;
+
+  TVM_DECLARE_ATTRS(EinsumAttrs, "relay.attrs.EinsumAttrs") {
+    TVM_ATTR_FIELD(equation).describe("The einsum expression string");
+  }
+};  // struct EinsumAttrs
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 48a5b48b37d3..1b32ea4e9e88 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -3517,6 +3517,15 @@ def _impl_v11(cls, inputs, attr, params):
         return _expr.TupleWrapper(_expr.Tuple([unique_vals, indices, inverse_indices, counts]), 4)
 
 
+class Einsum(OnnxOpConverter):
+    """Operator converter for Einsum"""
+
+    @classmethod
+    def _impl_v12(cls, inputs, attr, params):
+        equation = attr["equation"].decode("utf-8")
+        return _op.einsum(inputs, equation)
+
+
 class RandomUniform(OnnxOpConverter):
     """Operator converter for random_uniform"""
 
@@ -3763,6 +3772,7 @@ def _get_convert_map(opset):
         "Range": Range.get_converter(opset),
         "CumSum": CumSum.get_converter(opset),
         "Unique": Unique.get_converter(opset),
+        "Einsum": Einsum.get_converter(opset),
         # defs/control_flow
         "Loop": Loop.get_converter(opset),
         "If": If.get_converter(opset),
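
With the converter above in place, an ONNX Einsum node (opset 12 or later) lowers to the new relay op. A minimal sketch of exercising it through the frontend; the model-building boilerplate uses the standard onnx.helper API and is illustrative only, not part of this patch:

import onnx
from onnx import TensorProto, helper
from tvm import relay

# Build a one-node ONNX graph: out = einsum("ij,jk->ik", a, b), i.e. a matmul.
node = helper.make_node("Einsum", ["a", "b"], ["out"], equation="ij,jk->ik")
graph = helper.make_graph(
    [node],
    "einsum_matmul",
    inputs=[
        helper.make_tensor_value_info("a", TensorProto.FLOAT, [2, 3]),
        helper.make_tensor_value_info("b", TensorProto.FLOAT, [3, 4]),
    ],
    outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, [2, 4])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 12)])

# The Einsum node should lower to a relay.einsum call via Einsum._impl_v12.
mod, params = relay.frontend.from_onnx(model)
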
diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py
index 2e509a111c4a..825bd1f627ca 100644
--- a/python/tvm/relay/op/__init__.py
+++ b/python/tvm/relay/op/__init__.py
@@ -54,6 +54,7 @@
 from . import _transform
 from . import _reduce
 from . import _algorithm
+from . import _math
 
 
 def _register_op_make():
diff --git a/python/tvm/relay/op/_math.py b/python/tvm/relay/op/_math.py
new file mode 100644
index 000000000000..ff74fafcef75
--- /dev/null
+++ b/python/tvm/relay/op/_math.py
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Backend compiler related feature registration"""
+from . import op as _reg
+from . import strategy
+
+# einsum
+_reg.register_strategy("einsum", strategy.einsum_strategy)
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 7474f8e22bb5..0284d2483ce5 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -182,6 +182,7 @@ def compute_unique(attrs, inputs, output_type):
 _reg.register_strategy("invert_permutation", strategy.invert_permutation_strategy)
 _reg.register_shape_func("invert_permutation", False, elemwise_shape_func)
 
+
 #####################
 #  Shape functions  #
 #####################
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index ba47ae7bc4f1..3a3e3037d274 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -1210,3 +1210,16 @@ def invert_permutation_strategy_cuda(attrs, inputs, out_type, target):
         name="invert_permutation.cuda",
     )
     return strategy
+
+
+@einsum_strategy.register(["cuda", "gpu"])
+def einsum_strategy_cuda(attrs, inputs, out_type, target):
+    """einsum cuda strategy"""
+    strategy = _op.OpStrategy()
+    # TODO: Add cuda-specific op implementation for einsum
+    strategy.add_implementation(
+        wrap_compute_einsum(topi.einsum),
+        wrap_topi_schedule(topi.generic.schedule_extern),
+        name="einsum.cuda",
+    )
+    return strategy
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 9c756f201721..2822585caeaf 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1669,3 +1669,24 @@ def invert_permutation_strategy(attrs, inputs, out_type, target):
         name="invert_permutation.generic",
     )
     return strategy
+
+
+def wrap_compute_einsum(topi_compute):
+    """Wrap einsum topi compute"""
+
+    def _compute_einsum(attrs, inputs, _):
+        return [topi_compute(attrs.equation, *inputs)]
+
+    return _compute_einsum
+
+
+@override_native_generic_func("einsum_strategy")
+def einsum_strategy(attrs, inputs, out_type, target):
+    """einsum generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_einsum(topi.einsum),
+        wrap_topi_schedule(topi.generic.schedule_einsum),
+        name="einsum.generic",
+    )
+    return strategy
diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py
index a38a23064d6f..e47928919ce1 100644
--- a/python/tvm/relay/op/tensor.py
+++ b/python/tvm/relay/op/tensor.py
@@ -1104,6 +1104,29 @@ def concatenate(data, axis):
     return _make.concatenate(Tuple(data), axis)
 
 
+def einsum(data, equation):
+    """Evaluates the Einstein summation convention on data
+
+    Parameters
+    ----------
+    data : Union(List[relay.Expr], Tuple[relay.Expr])
+        A list of tensors.
+    equation : str
+        The einsum expression string.
+
+    Returns
+    -------
+    result : relay.Expr
+        The output tensor from the einsum op.
+    """
+    data = list(data)
+    if not data:
+        raise ValueError("relay.einsum requires data to be non-empty.")
+    if not isinstance(equation, str):
+        raise ValueError("einsum `equation` must be a str")
+    return _make.einsum(Tuple(data), equation)
+
+
 def stack(data, axis):
     """Join a sequence of arrays along a new axis.
 
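For reference, a minimal usage sketch of the new Python API (not part of the patch; assumes a build with this change applied, and the matmul equation and shapes are arbitrary):

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 3), dtype="float32")
y = relay.var("y", shape=(3, 4), dtype="float32")
# relay.einsum takes a list/tuple of tensors plus the equation string.
f = relay.Function([x, y], relay.einsum([x, y], "ij,jk->ik"))

a = np.random.rand(2, 3).astype("float32")
b = np.random.rand(3, 4).astype("float32")
out = relay.create_executor("graph", mod=tvm.IRModule.from_expr(f)).evaluate()(a, b)
# Result should match NumPy's einsum for the same equation.
np.testing.assert_allclose(out.numpy(), np.einsum("ij,jk->ik", a, b), rtol=1e-5)
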
diff --git a/python/tvm/topi/generic/__init__.py b/python/tvm/topi/generic/__init__.py
index cc64abab8ed8..021f9a1bbe1d 100644
--- a/python/tvm/topi/generic/__init__.py
+++ b/python/tvm/topi/generic/__init__.py
@@ -39,3 +39,4 @@
 from .sort import *
 from .search import *
 from .image import *
+from .math import *
diff --git a/python/tvm/topi/generic/math.py b/python/tvm/topi/generic/math.py
new file mode 100644
index 000000000000..3af6cd16a374
--- /dev/null
+++ b/python/tvm/topi/generic/math.py
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Generic math operators"""
+from .default import default_schedule as _default_schedule
+
+
+def schedule_einsum(outs):
+    """Schedule for einsum operator.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of einsum.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
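
The generic schedule simply returns the default schedule for the einsum compute. A rough TE-level sketch of what the strategy wires together, assuming the existing topi.einsum compute (note that default_schedule requires an active target context, hence the `with` block):

import tvm
from tvm import te, topi

A = te.placeholder((2, 3), name="A", dtype="float32")
B = te.placeholder((3, 4), name="B", dtype="float32")
# topi.einsum(subscripts, *operands) is the compute the strategies wrap.
C = topi.einsum("ij,jk->ik", A, B)

with tvm.target.Target("llvm"):
    s = topi.generic.schedule_einsum([C])
f = tvm.build(s, [A, B, C], target="llvm")
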
diff --git a/src/relay/op/tensor/math.cc b/src/relay/op/tensor/math.cc
new file mode 100644
index 000000000000..246fba62cc66
--- /dev/null
+++ b/src/relay/op/tensor/math.cc
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file math.cc
+ * \brief Math operators.
+ */
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+
+#include "../make_op.h"
+#include "../op_common.h"
+#include "../type_relations.h"
+
+namespace tvm {
+namespace relay {
+
+// relay.einsum
+TVM_REGISTER_NODE_TYPE(EinsumAttrs);
+
+bool EinsumRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+               const TypeReporter& reporter) {
+  // Check attrs
+  const EinsumAttrs* param = attrs.as<EinsumAttrs>();
+  if (param == nullptr) {
+    reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+                                     << "the call attributes are not defined");
+    return false;
+  }
+
+  // types: [data, result]
+  ICHECK_EQ(types.size(), 2) << "the arity of einsum is 2, not " << types.size();
+
+  // Check input type is a tuple.
+  const auto* tensor_tuple = types[0].as<TupleTypeNode>();
+  if (tensor_tuple == nullptr) {
+    reporter->GetDiagCtx().EmitFatal(
+        Diagnostic::Error(reporter->GetSpan())
+        << "einsum requires a tuple of tensors as the first argument, found "
+        << PrettyPrint(types[0]));
+    return false;
+  }
+
+  // Check the input tuple consists of tensors with consistent dtype.
+  const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
+  const DataType dtype = first->dtype;
+  std::vector<Array<PrimExpr>> input_shapes;
+  for (const Type& ele : tensor_tuple->fields) {
+    if (ele.as<IncompleteTypeNode>()) {
+      return false;
+    }
+
+    const auto& e = Downcast<TensorType>(ele);
+
+    const DataType& e_dtype = e->dtype;
+    if (e_dtype != dtype) {
+      throw Error("relay.einsum requires all tensors have the same dtype");
+    }
+    input_shapes.push_back(e->shape);
+  }
+
+  // Calculate output shape
+  Array<IndexExpr> oshape = topi::NumpyEinsumShape(param->equation, input_shapes);
+
+  auto rtype = TensorType(oshape, dtype);
+  reporter->Assign(types[1], rtype);
+  return true;
+}
+
+Array<te::Tensor> EinsumCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                const Type& out_type) {
+  const EinsumAttrs* param = attrs.as<EinsumAttrs>();
+  ICHECK(param != nullptr);
+  return Array<te::Tensor>{topi::einsum(param->equation, inputs)};
+}
+
+Expr MakeEinsum(Expr data, String equation) {
+  auto attrs = make_object<EinsumAttrs>();
+  attrs->equation = std::move(equation);
+  static const Op& op = Op::Get("einsum");
+  return Call(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.einsum").set_body_typed(MakeEinsum);
+
+RELAY_REGISTER_OP("einsum")
+    .describe(R"doc(Evaluates the Einstein summation convention
+on the operands)doc" TVM_ADD_FILELINE)
+    .set_attrs_type<EinsumAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tuple of Tensors", "The input list of tensors.")
+    .set_support_level(11)
+    .add_type_rel("Einsum", EinsumRel)
+    .set_attr<FTVMCompute>("FTVMCompute", EinsumCompute)
+    .set_attr<TOpPattern>("TOpPattern", kInjective);
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 6fa4fcf47ccb..f9094e94a3ee 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4739,11 +4739,6 @@ def verify_eyelike(indata):
     "test_dropout_default_mask",
     "test_dropout_default_mask_ratio",
     "test_dropout_default_ratio",
-    "test_einsum_batch_diagonal",
-    "test_einsum_batch_matmul",
-    "test_einsum_inner_prod",
-    "test_einsum_sum",
-    "test_einsum_transpose",
     "test_greater_equal",
     "test_greater_equal_bcast",
     "test_if_seq",
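
Since EinsumRel derives the output shape through topi::NumpyEinsumShape, type inference alone determines result shapes, without running the op. An illustrative check (a sketch, not part of the patch; the equation "ii->i" extracts the matrix diagonal, so a (5, 5) input infers a (5,) output):

import tvm
from tvm import relay

x = relay.var("x", shape=(5, 5), dtype="float32")
f = relay.Function([x], relay.einsum([x], "ii->i"))  # diagonal extraction
mod = relay.transform.InferType()(tvm.IRModule.from_expr(f))
print(mod["main"].ret_type)  # expected: Tensor[(5,), float32]
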