From bfcac8c240760b1db700629cd9a22b93e4214c8c Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Fri, 2 Jul 2021 14:51:46 +0100 Subject: [PATCH 1/5] [TIR][USMP] Add a parallel to serial for loop converter pass This is an optional pass to convert all parallel for loops in TIR to serial ones for different reasons such as executor does not support parallel launch of for loops (e.g., AoT) or allocating space for parallel for loops might not be desired. * Additionally adding FFI scaffolding for USMP Change-Id: Id5e8ccb90140d2d3ae113b20a3ca152a54497c45 --- include/tvm/tir/usmp/transform.h | 44 +++++++++++++ python/tvm/tir/__init__.py | 1 + python/tvm/tir/usmp/__init__.py | 21 +++++++ python/tvm/tir/usmp/_ffi_api.py | 21 +++++++ python/tvm/tir/usmp/transform/__init__.py | 20 ++++++ python/tvm/tir/usmp/transform/_ffi_api.py | 21 +++++++ python/tvm/tir/usmp/transform/transform.py | 36 +++++++++++ .../transform/convert_for_loops_serial.cc | 62 +++++++++++++++++++ ...usmp_transform_convert_for_loops_serial.py | 60 ++++++++++++++++++ 9 files changed, 286 insertions(+) create mode 100644 include/tvm/tir/usmp/transform.h create mode 100644 python/tvm/tir/usmp/__init__.py create mode 100644 python/tvm/tir/usmp/_ffi_api.py create mode 100644 python/tvm/tir/usmp/transform/__init__.py create mode 100644 python/tvm/tir/usmp/transform/_ffi_api.py create mode 100644 python/tvm/tir/usmp/transform/transform.py create mode 100644 src/tir/usmp/transform/convert_for_loops_serial.cc create mode 100644 tests/python/unittest/test_tir_usmp_transform_convert_for_loops_serial.py diff --git a/include/tvm/tir/usmp/transform.h b/include/tvm/tir/usmp/transform.h new file mode 100644 index 000000000000..32eca31d2aba --- /dev/null +++ b/include/tvm/tir/usmp/transform.h @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/tir/analysis.h + * \brief Analysis utilities and passes for TIR Unified Static Memory Planner. + */ +#ifndef TVM_TIR_USMP_TRANSFORM_H_ +#define TVM_TIR_USMP_TRANSFORM_H_ + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { +namespace usmp { + +TVM_DLL Stmt ConvertForLoopsToSerial(const PrimFunc& func); + +} +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_USMP_TRANSFORM_H_ diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index eb200df0c599..a7f9ec22ae4f 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -55,3 +55,4 @@ from . import transform from . import analysis from . import stmt_functor +from . import usmp diff --git a/python/tvm/tir/usmp/__init__.py b/python/tvm/tir/usmp/__init__.py new file mode 100644 index 000000000000..c1a08565d18a --- /dev/null +++ b/python/tvm/tir/usmp/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import, redefined-builtin +"""Namespace for Unified Static Memory Planner""" + +from . import transform +from . import analysis diff --git a/python/tvm/tir/usmp/_ffi_api.py b/python/tvm/tir/usmp/_ffi_api.py new file mode 100644 index 000000000000..5899ef0c86ea --- /dev/null +++ b/python/tvm/tir/usmp/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.tir.usmp""" +import tvm._ffi + + +tvm._ffi._init_api("tir.usmp", __name__) diff --git a/python/tvm/tir/usmp/transform/__init__.py b/python/tvm/tir/usmp/transform/__init__.py new file mode 100644 index 000000000000..2835d146dbd1 --- /dev/null +++ b/python/tvm/tir/usmp/transform/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import, redefined-builtin +"""Namespace for USMP's transform passes""" + +from .transform import for_loop_serial_converter diff --git a/python/tvm/tir/usmp/transform/_ffi_api.py b/python/tvm/tir/usmp/transform/_ffi_api.py new file mode 100644 index 000000000000..67684b34a3f7 --- /dev/null +++ b/python/tvm/tir/usmp/transform/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.tir.usmp.transform""" +import tvm._ffi + + +tvm._ffi._init_api("tir.usmp.transform", __name__) diff --git a/python/tvm/tir/usmp/transform/transform.py b/python/tvm/tir/usmp/transform/transform.py new file mode 100644 index 000000000000..5b7ec64af556 --- /dev/null +++ b/python/tvm/tir/usmp/transform/transform.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""USMP Transform Python API for passes""" +# pylint: disable=invalid-name +from . import _ffi_api +from ...function import PrimFunc + + +def for_loop_serial_converter(func: PrimFunc): + """Convert Parallel For Loop to Serial. + + Parameters + ---------- + func: tvm.tir.PrimFunc + The function to be converted. + + Returns + ------- + tvm.tir.PrimFunc + converted function + """ + return _ffi_api.for_loop_serial_converter(func) diff --git a/src/tir/usmp/transform/convert_for_loops_serial.cc b/src/tir/usmp/transform/convert_for_loops_serial.cc new file mode 100644 index 000000000000..f71a09430762 --- /dev/null +++ b/src/tir/usmp/transform/convert_for_loops_serial.cc @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/analysis/usmp/convert_for_loops_serial.cc + * \brief Convert all for loops to serial for lesser memory consumption + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { +namespace usmp { + +class ForLoopSerialConverter : public StmtExprMutator { + public: + ForLoopSerialConverter() = default; + Stmt operator()(const PrimFunc& func); + + private: + Stmt VisitStmt_(const ForNode* op) override; +}; + +Stmt ForLoopSerialConverter::VisitStmt_(const ForNode* op) { + if (op->kind == ForKind::kParallel) { + return For(op->loop_var, op->min, op->extent, ForKind::kSerial, op->body, op->thread_binding, + op->annotations, op->span); + } + return StmtExprMutator::VisitStmt_(op); +} + +Stmt ForLoopSerialConverter::operator()(const PrimFunc& func) { + return this->VisitStmt(func->body); +} + +Stmt ConvertForLoopsToSerial(const PrimFunc& func) { return ForLoopSerialConverter()(func); } + +TVM_REGISTER_GLOBAL("tir.usmp.transform.for_loop_serial_converter") + .set_body_typed([](PrimFunc func) { return (ConvertForLoopsToSerial(func)); }); + +} // namespace usmp +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_for_loops_serial.py b/tests/python/unittest/test_tir_usmp_transform_convert_for_loops_serial.py new file mode 100644 index 000000000000..1eb64227602e --- /dev/null +++ b/tests/python/unittest/test_tir_usmp_transform_convert_for_loops_serial.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm import tir, script +from tvm.script import ty +from tvm.tir import stmt_functor + +# fmt: off +@tvm.script.tir +def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: ty.handle, placeholder_31: ty.handle, placeholder_32: ty.handle, T_cast_8: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True}) + placeholder_33 = tir.match_buffer(placeholder_30, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_34 = tir.match_buffer(placeholder_31, [1, 1, 192, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_35 = tir.match_buffer(placeholder_32, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_9 = tir.match_buffer(T_cast_8, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_3 = tir.allocate([1, 28, 28, 192], "int16", "global") + for i0_i1_fused_3 in tir.parallel(0, 28): + for i2_3, i3_3 in tir.grid(28, 192): + tir.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), tir.load("int16", placeholder_33.data, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)), True) + for ax0_ax1_fused_ax2_fused_3 in tir.parallel(0, 784): + for ax3_2 in tir.serial(0, 16): + Conv2dOutput_3 = tir.allocate([1, 1, 1, 1], "int32", "global") + tir.store(Conv2dOutput_3, 0, 0, True) + for rc_3 in tir.serial(0, 192): + tir.store(Conv2dOutput_3, 0, (tir.load("int32", Conv2dOutput_3, 0) + (tir.cast(tir.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3*192) + rc_3)), "int32")*tir.cast(tir.load("int16", placeholder_34.data, ((rc_3*16) + ax3_2)), "int32"))), True) + tir.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3*16) + ax3_2), tir.cast(tir.cast(tir.max(tir.min(tir.q_multiply_shift((tir.load("int32", Conv2dOutput_3, 0) + tir.load("int32", placeholder_35.data, ax3_2)), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) +# fmt: on + + +def test_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(): + primfunc = fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2 + primfunc = tvm.tir.usmp.transform.for_loop_serial_converter(primfunc) + + def verify_serial_loops(stmt): + if isinstance(stmt, tvm.tir.For): + assert stmt.kind == tvm.tir.ForKind.SERIAL + + stmt_functor.post_order_visit(primfunc.body, verify_serial_loops) + + +if __name__ == "__main__": + pytest.main([__file__]) From 2b9f5d7dba3ea38d46537572fda99230cb61db49 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Wed, 14 Jul 2021 18:24:07 +0100 Subject: [PATCH 2/5] [TIR][USMP] Add a parallel to serial for loop converter pass * remove unused import Change-Id: I29d5fdec92120418596f9dba1d6630f65620a603 --- python/tvm/tir/usmp/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/tir/usmp/__init__.py b/python/tvm/tir/usmp/__init__.py index c1a08565d18a..ca86890f22cf 100644 --- a/python/tvm/tir/usmp/__init__.py +++ b/python/tvm/tir/usmp/__init__.py @@ -18,4 +18,3 @@ """Namespace for Unified Static Memory Planner""" from . import transform -from . import analysis From 3dac79975b20c2ed4306ab208d4327e5456e6ac5 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Tue, 20 Jul 2021 10:15:10 +0100 Subject: [PATCH 3/5] [TIR][USMP] Add a parallel to serial for loop converter pass *moved the pass to tir namespace Change-Id: I74720ca2f566066b3a4f22f504d8f0f684c99dc2 --- include/tvm/tir/transform.h | 9 ++++ include/tvm/tir/usmp/transform.h | 44 ------------------- python/tvm/tir/__init__.py | 1 - python/tvm/tir/transform/transform.py | 16 +++++++ python/tvm/tir/usmp/__init__.py | 20 --------- python/tvm/tir/usmp/_ffi_api.py | 21 --------- python/tvm/tir/usmp/transform/__init__.py | 20 --------- python/tvm/tir/usmp/transform/_ffi_api.py | 21 --------- python/tvm/tir/usmp/transform/transform.py | 36 --------------- .../convert_for_loops_serial.cc | 27 +++++++++--- ...tir_transform_convert_for_loops_serial.py} | 6 ++- 11 files changed, 49 insertions(+), 172 deletions(-) delete mode 100644 include/tvm/tir/usmp/transform.h delete mode 100644 python/tvm/tir/usmp/__init__.py delete mode 100644 python/tvm/tir/usmp/_ffi_api.py delete mode 100644 python/tvm/tir/usmp/transform/__init__.py delete mode 100644 python/tvm/tir/usmp/transform/_ffi_api.py delete mode 100644 python/tvm/tir/usmp/transform/transform.py rename src/tir/{usmp/transform => transforms}/convert_for_loops_serial.cc (72%) rename tests/python/unittest/{test_tir_usmp_transform_convert_for_loops_serial.py => test_tir_transform_convert_for_loops_serial.py} (93%) diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index d1308fe0059e..744522547b5c 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -442,6 +442,15 @@ TVM_DLL Pass FlattenBuffer(); */ TVM_DLL Pass MergeDynamicSharedMemoryAllocations(); +/*! + * \brief This pass is post-scheduling pass to convert all + * Parallel For loops to Serial ones. This is run + * to attain lesser memory and/or executor/backend + * does not support parallel launch of For loops. + * \return The pass. + */ +TVM_DLL Pass ConvertForLoopsToSerial(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/include/tvm/tir/usmp/transform.h b/include/tvm/tir/usmp/transform.h deleted file mode 100644 index 32eca31d2aba..000000000000 --- a/include/tvm/tir/usmp/transform.h +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/tir/analysis.h - * \brief Analysis utilities and passes for TIR Unified Static Memory Planner. - */ -#ifndef TVM_TIR_USMP_TRANSFORM_H_ -#define TVM_TIR_USMP_TRANSFORM_H_ - -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace tir { -namespace usmp { - -TVM_DLL Stmt ConvertForLoopsToSerial(const PrimFunc& func); - -} -} // namespace tir -} // namespace tvm - -#endif // TVM_TIR_USMP_TRANSFORM_H_ diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index a7f9ec22ae4f..eb200df0c599 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -55,4 +55,3 @@ from . import transform from . import analysis from . import stmt_functor -from . import usmp diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 537499a27fa9..53dead3f83c3 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -678,3 +678,19 @@ def MergeDynamicSharedMemoryAllocations(): The result pass """ return _ffi_api.MergeDynamicSharedMemoryAllocations() # type: ignore + + +def ConvertForLoopsToSerial(): + """Convert Parallel For Loop to Serial. + + Parameters + ---------- + func: tvm.tir.PrimFunc + The function to be converted. + + Returns + ------- + tvm.tir.PrimFunc + converted function + """ + return _ffi_api.ConvertForLoopsToSerial() diff --git a/python/tvm/tir/usmp/__init__.py b/python/tvm/tir/usmp/__init__.py deleted file mode 100644 index ca86890f22cf..000000000000 --- a/python/tvm/tir/usmp/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-import, redefined-builtin -"""Namespace for Unified Static Memory Planner""" - -from . import transform diff --git a/python/tvm/tir/usmp/_ffi_api.py b/python/tvm/tir/usmp/_ffi_api.py deleted file mode 100644 index 5899ef0c86ea..000000000000 --- a/python/tvm/tir/usmp/_ffi_api.py +++ /dev/null @@ -1,21 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""FFI APIs for tvm.tir.usmp""" -import tvm._ffi - - -tvm._ffi._init_api("tir.usmp", __name__) diff --git a/python/tvm/tir/usmp/transform/__init__.py b/python/tvm/tir/usmp/transform/__init__.py deleted file mode 100644 index 2835d146dbd1..000000000000 --- a/python/tvm/tir/usmp/transform/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-import, redefined-builtin -"""Namespace for USMP's transform passes""" - -from .transform import for_loop_serial_converter diff --git a/python/tvm/tir/usmp/transform/_ffi_api.py b/python/tvm/tir/usmp/transform/_ffi_api.py deleted file mode 100644 index 67684b34a3f7..000000000000 --- a/python/tvm/tir/usmp/transform/_ffi_api.py +++ /dev/null @@ -1,21 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""FFI APIs for tvm.tir.usmp.transform""" -import tvm._ffi - - -tvm._ffi._init_api("tir.usmp.transform", __name__) diff --git a/python/tvm/tir/usmp/transform/transform.py b/python/tvm/tir/usmp/transform/transform.py deleted file mode 100644 index 5b7ec64af556..000000000000 --- a/python/tvm/tir/usmp/transform/transform.py +++ /dev/null @@ -1,36 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""USMP Transform Python API for passes""" -# pylint: disable=invalid-name -from . import _ffi_api -from ...function import PrimFunc - - -def for_loop_serial_converter(func: PrimFunc): - """Convert Parallel For Loop to Serial. - - Parameters - ---------- - func: tvm.tir.PrimFunc - The function to be converted. - - Returns - ------- - tvm.tir.PrimFunc - converted function - """ - return _ffi_api.for_loop_serial_converter(func) diff --git a/src/tir/usmp/transform/convert_for_loops_serial.cc b/src/tir/transforms/convert_for_loops_serial.cc similarity index 72% rename from src/tir/usmp/transform/convert_for_loops_serial.cc rename to src/tir/transforms/convert_for_loops_serial.cc index f71a09430762..d01ae8a45113 100644 --- a/src/tir/usmp/transform/convert_for_loops_serial.cc +++ b/src/tir/transforms/convert_for_loops_serial.cc @@ -18,18 +18,17 @@ */ /*! - * \file tir/analysis/usmp/convert_for_loops_serial.cc + * \file tir/transforms/convert_for_loops_serial.cc * \brief Convert all for loops to serial for lesser memory consumption */ #include #include #include #include -#include +#include namespace tvm { namespace tir { -namespace usmp { class ForLoopSerialConverter : public StmtExprMutator { public: @@ -52,11 +51,25 @@ Stmt ForLoopSerialConverter::operator()(const PrimFunc& func) { return this->VisitStmt(func->body); } -Stmt ConvertForLoopsToSerial(const PrimFunc& func) { return ForLoopSerialConverter()(func); } +PrimFunc ConvertForLoopsToSerial(PrimFunc func) { + PrimFuncNode* fptr = func.CopyOnWrite(); + fptr->body = ForLoopSerialConverter()(func); + return func; +} + +namespace transform { + +Pass ConvertForLoopsToSerial() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return ConvertForLoopsToSerial(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.ConvertForLoopsToSerial", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.ConvertForLoopsToSerial") + .set_body_typed(ConvertForLoopsToSerial); -TVM_REGISTER_GLOBAL("tir.usmp.transform.for_loop_serial_converter") - .set_body_typed([](PrimFunc func) { return (ConvertForLoopsToSerial(func)); }); +} // namespace transform -} // namespace usmp } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_for_loops_serial.py b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py similarity index 93% rename from tests/python/unittest/test_tir_usmp_transform_convert_for_loops_serial.py rename to tests/python/unittest/test_tir_transform_convert_for_loops_serial.py index 1eb64227602e..272e0d45410f 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_for_loops_serial.py +++ b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py @@ -47,13 +47,15 @@ def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: ty def test_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(): primfunc = fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2 - primfunc = tvm.tir.usmp.transform.for_loop_serial_converter(primfunc) + mod = tvm.IRModule.from_expr(primfunc) + mod = tvm.tir.transform.ConvertForLoopsToSerial()(mod) def verify_serial_loops(stmt): if isinstance(stmt, tvm.tir.For): assert stmt.kind == tvm.tir.ForKind.SERIAL - stmt_functor.post_order_visit(primfunc.body, verify_serial_loops) + for _, primfunc in mod.functions.items(): + stmt_functor.post_order_visit(primfunc.body, verify_serial_loops) if __name__ == "__main__": From a424b9ad5af35aa759ad79aab741f890f22ff639 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Tue, 20 Jul 2021 10:44:35 +0100 Subject: [PATCH 4/5] [TIR][USMP] Add a parallel to serial for loop converter pass * fixed docstring Change-Id: I73bb9867fe2ed6a86f65666493c5c6e3edf87b49 --- python/tvm/tir/transform/transform.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 53dead3f83c3..55c854318632 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -681,16 +681,11 @@ def MergeDynamicSharedMemoryAllocations(): def ConvertForLoopsToSerial(): - """Convert Parallel For Loop to Serial. - - Parameters - ---------- - func: tvm.tir.PrimFunc - The function to be converted. + """Convert Parallel For Loops to Serial For Loops. Returns ------- - tvm.tir.PrimFunc - converted function + fpass : tvm.transform.Pass + The result pass """ return _ffi_api.ConvertForLoopsToSerial() From 26571f826796cf6f7f45c680c9aea68eb8fdd24b Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Fri, 6 Aug 2021 17:16:06 +0100 Subject: [PATCH 5/5] [TIR][USMP] Add a parallel to serial for loop converter pass * fixed mypy lint error Change-Id: I226ef27d5536674fbe4b2d2c6ff47b8cb3b41431 --- python/tvm/tir/transform/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 55c854318632..f52b21af666a 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -688,4 +688,4 @@ def ConvertForLoopsToSerial(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.ConvertForLoopsToSerial() + return _ffi_api.ConvertForLoopsToSerial() # type: ignore