From f1332a40fe666c67ee02844c219db3d966821d9d Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 24 Feb 2023 09:38:29 +0300 Subject: [PATCH 1/2] support mean reduction, clean comments, extend tests --- include/tvm/relay/attrs/transform.h | 2 +- python/tvm/relay/frontend/pytorch.py | 3 --- python/tvm/relay/op/transform.py | 5 +++-- python/tvm/topi/cuda/scatter_elements.py | 10 ++++++++-- python/tvm/topi/scatter_elements.py | 10 ++++++++-- tests/python/frontend/pytorch/test_forward.py | 6 ++---- 6 files changed, 22 insertions(+), 14 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 7680883248e0..51378c86979d 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -164,7 +164,7 @@ struct ScatterElementsAttrs : public tvm::AttrsNode { TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values."); TVM_ATTR_FIELD(reduction).set_default("update").describe( "Reduction mode of the scatter elements, " - "either \"update\", \"add\", \"mul\", \"min\" or \"max\"."); + "either \"update\", \"add\", \"mul\", \"mean\", \"min\" or \"max\"."); } }; diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 635cb960a829..57997bb89416 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2766,9 +2766,6 @@ def scatter_reduce(self, inputs, input_types): reduce = "min" elif reduce == "amax": reduce = "max" - else: # reduce == "mean" - # TODO(vvchernov): support mean reduction - raise NotImplementedError("Mean reduction has not been supported yet!") return _op.scatter_elements(data, index, src, axis=dim, reduction=reduce) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 833d14eb5897..6718347b3168 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -397,10 +397,11 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): The axis to scatter elements on. It is zero by default. reduction : string, optional - The reduction mode for scatter. Choise is from ["update", "add", "mul", "min", max"] + The reduction mode for scatter. Choise is from ["update", "add", "mul", "mean", "min", max"] If update, the update values will replace the input data If add, the update values will be added to the input data - If mul, the update values will be multiply to the input data + If mul, the input data will be multiplied on the update values + If mean, the input data will be mean between the update values and the input data If min, there is choice of minimal between the update values and the input data If max, there is choice of maximal between the update values and the input data It is "update" by default diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py index 25f15a0e73a6..e5fa03311692 100644 --- a/python/tvm/topi/cuda/scatter_elements.py +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -234,10 +234,11 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): The axis to scatter on. It is zero by default. reduction : optional, string - The update mode for the algorithm, either "update", "add", "mul", "min" or "max" + The update mode for the algorithm, either "update", "add", "mul", "mean", "min" or "max" If update, the update values will replace the input data If add, the update values will be added to the input data - If mul, the update values will be multiply to the input data + If mul, the input data will be multiplied on the update values + If mean, the input data will be mean between the update values and the input data If min, there is choice of minimal between the update values and the input data If max, there is choice of maximal between the update values and the input data It is "update" by default @@ -258,6 +259,9 @@ def add_func(dst_ptr, dst_index, update): def mul_func(dst_ptr, dst_index, update): dst_ptr[dst_index] *= update + def mean_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] = (dst_ptr[dst_index] + update)/2 + def min_func(dst_ptr, dst_index, update): dst_ptr[dst_index] = tir.min(dst_ptr[dst_index], update) @@ -271,6 +275,8 @@ def max_func(dst_ptr, dst_index, update): reduce_func = add_func elif reduction == "mul": reduce_func = mul_func + elif reduction == "mean": + reduce_func = mean_func elif reduction == "min": reduce_func = min_func elif reduction == "max": diff --git a/python/tvm/topi/scatter_elements.py b/python/tvm/topi/scatter_elements.py index b4052702268b..9acf618321da 100644 --- a/python/tvm/topi/scatter_elements.py +++ b/python/tvm/topi/scatter_elements.py @@ -54,7 +54,8 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): The update mode for the algorithm, either "update", "add", "mul", "min" or "max" If update, the update values will replace the input data If add, the update values will be added to the input data - If mul, the update values will be multiply to the input data + If mul, the input data will be multiplied on the update values + If mean, the input data will be mean between the update values and the input data If min, there is choice of minimal between the update values and the input data If max, there is choice of maximal between the update values and the input data It is "update" by default @@ -133,6 +134,9 @@ def add_func(dst_ptr, dst_index, update): def mul_func(dst_ptr, dst_index, update): dst_ptr[dst_index] *= update + def mean_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] = (dst_ptr[dst_index] + update)/2 + def min_func(dst_ptr, dst_index, update): dst_ptr[dst_index] = tir.min(dst_ptr[dst_index], update) @@ -146,13 +150,15 @@ def max_func(dst_ptr, dst_index, update): reduce_func = add_func elif reduction == "mul": reduce_func = mul_func + elif reduction == "mean": + reduce_func = mean_func elif reduction == "min": reduce_func = min_func elif reduction == "max": reduce_func = max_func else: raise NotImplementedError( - "scatter_elements reduction not in [update, add, mul, min, max]:", reduction + "scatter_elements reduction not in [update, add, mul, mean, min, max]:", reduction ) out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf") diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 21defa6a59b2..2401e98bceb1 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -4252,16 +4252,14 @@ def test_fn_scatter_reduce(dim, reduce): in_src = torch.rand(2, 5) - 1 targets = ["llvm", "cuda"] - # TODO(vvchernov): support test of mean reduction and include_self=False - for reduce in ["sum", "prod", "amin", "amax"]: + for reduce in ["sum", "prod", "amin", "amax", "mean"]: verify_trace_model(test_fn_scatter_reduce(0, reduce), [in_data, in_index, in_src], targets) in_data = torch.rand(2, 4) - 1 in_index = torch.tensor([[2], [3]]) in_src = torch.rand(2, 1) - 1 - # TODO(vvchernov): support test of mean reduction and include_self=False - for reduce in ["sum", "prod", "amin", "amax"]: + for reduce in ["sum", "prod", "amin", "amax", "mean"]: verify_trace_model(test_fn_scatter_reduce(1, reduce), [in_data, in_index, in_src], targets) From 075d3bea2a0af6f3d83b72d9f8763eabfee98b5c Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 24 Feb 2023 10:16:00 +0300 Subject: [PATCH 2/2] fix pylint --- python/tvm/topi/cuda/scatter_elements.py | 2 +- python/tvm/topi/scatter_elements.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py index e5fa03311692..2f345b9d67ec 100644 --- a/python/tvm/topi/cuda/scatter_elements.py +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -260,7 +260,7 @@ def mul_func(dst_ptr, dst_index, update): dst_ptr[dst_index] *= update def mean_func(dst_ptr, dst_index, update): - dst_ptr[dst_index] = (dst_ptr[dst_index] + update)/2 + dst_ptr[dst_index] = (dst_ptr[dst_index] + update) / 2 def min_func(dst_ptr, dst_index, update): dst_ptr[dst_index] = tir.min(dst_ptr[dst_index], update) diff --git a/python/tvm/topi/scatter_elements.py b/python/tvm/topi/scatter_elements.py index 9acf618321da..4c35578ffd36 100644 --- a/python/tvm/topi/scatter_elements.py +++ b/python/tvm/topi/scatter_elements.py @@ -135,7 +135,7 @@ def mul_func(dst_ptr, dst_index, update): dst_ptr[dst_index] *= update def mean_func(dst_ptr, dst_index, update): - dst_ptr[dst_index] = (dst_ptr[dst_index] + update)/2 + dst_ptr[dst_index] = (dst_ptr[dst_index] + update) / 2 def min_func(dst_ptr, dst_index, update): dst_ptr[dst_index] = tir.min(dst_ptr[dst_index], update)