From da226817186490d403aa9a5e839bffadc0bb0fcb Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 16 Mar 2023 02:51:21 -0700 Subject: [PATCH 1/7] upd --- include/tvm/tir/schedule/schedule.h | 10 +- python/tvm/tir/schedule/schedule.py | 74 ++++++++++- src/tir/schedule/concrete_schedule.cc | 8 ++ src/tir/schedule/concrete_schedule.h | 1 + src/tir/schedule/primitive.h | 10 ++ src/tir/schedule/primitive/block_annotate.cc | 116 ++++++++++++++++ src/tir/schedule/schedule.cc | 2 + src/tir/schedule/traced_schedule.cc | 10 ++ src/tir/schedule/traced_schedule.h | 1 + src/tir/schedule/transform.cc | 10 ++ src/tir/schedule/transform.h | 12 +- .../unittest/test_tir_schedule_set_dtype.py | 125 ++++++++++++++++++ 12 files changed, 375 insertions(+), 4 deletions(-) create mode 100644 tests/python/unittest/test_tir_schedule_set_dtype.py diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 22febfdfedec..c58283b8ae11 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -584,13 +584,21 @@ class ScheduleNode : public runtime::Object { virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, int offset) = 0; /*! - * \brief Set the storage scope of a buffer, where the buffer is specified by the a block and a + * \brief Set the storage scope of a buffer, where the buffer is specified by a block and a * write-index * \param block_rv The producer block of the buffer * \param buffer_index The index of the buffer in block's write region * \param storage_scope The storage scope to be set */ virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0; + /*! 
+ * \brief Set the data type of a buffer, where the buffer is specified by a block and a + * write-index + * \param block_rv The producer block of the buffer + * \param buffer_index the index of the buffer in block's write region + * \param dtype The data type to be set + */ + virtual void SetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) = 0; /******** Schedule: Blockize & Tensorize ********/ /*! * \brief Convert the subtree rooted at a specific loop into a block. diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 896e2fc48e72..b95d03d11142 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2322,7 +2322,7 @@ def after_storage_align(a: T.handle, c: T.handle) -> None: @type_checked def set_scope(self, block: Union[BlockRV, str], buffer_index: int, storage_scope: str) -> None: """Set the storage scope of a buffer, where the buffer is - specified by the a block and a write-index + specified by the a block and a write-index. Parameters ---------- @@ -2391,6 +2391,78 @@ def after_set_scope( self, block, buffer_index, storage_scope ) + @type_checked + def set_dtype(self, block: Union[BlockRV, str], buffer_index: int, dtype: str) -> None: + """Set the data type of a buffer, where the buffer is + specified by the a block and write-index. + + Parameters + ---------- + block : Union[BlockRV, str] + The producer block of the buffer + buffer_index : int + The index of the buffer in block's write region + dtype : str + The data type to be set + + Examples + -------- + + Before set_dtype, in TensorIR, the IR is: + + .. 
code-block:: python + + @T.prim_func + def before_set_dtype( + A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32") + ) -> None: + B = T.alloc_buffer((128, 128), dtype="float32") + + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + Create the schedule and do set_dtype: + + .. code-block:: python + + sch = tir.Schedule(before_set_dtype) + sch.set_dtype("B", buffer_index=0, dtype="float16") + print(sch.mod["main"].script()) + + After applying set_dtype, the IR becomes: + + .. code-block:: python + + @T.prim_func + def after_set_dtype( + A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32") + ) -> None: + B = T.alloc_buffer((128, 128), dtype="float16") + + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = T.cast(A[vi, vj] * 2.0, "float16") + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = T.cast(B[vi, vj], "float32") + 1.0 + + Note + ---- + set_dtype requires the buffer to be an intermediate buffer defined via `alloc_buffer`. 
+ """ + block = self._normalize_block_arg(block) + _ffi_api.ScheduleSetDType( # type: ignore # pylint: disable=no-member + self, block, buffer_index, dtype + ) + ########## Schedule: Blockize & Tensorize ########## @type_checked diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 5a9dab4854bd..14291b2403e0 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -701,6 +701,14 @@ void ConcreteScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, this->state_->DebugVerify(); } +void ConcreteScheduleNode::SetDType(const BlockRV& block_rv, int buffer_index, + const String& dtype) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::SetDType(state_, this->GetSRef(block_rv), buffer_index, dtype); + TVM_TIR_SCHEDULE_END("set-dtype", this->error_render_level_); + this->state_->DebugVerify(); +} + /******** Schedule: Reduction ********/ BlockRV ConcreteScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 82ac9f913374..c0d5743eecdd 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -141,6 +141,7 @@ class ConcreteScheduleNode : public ScheduleNode { void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, int offset) override; void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override; + void SetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) override; /******** Schedule: Blockize & Tensorize ********/ BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) override; void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 563864229a26..db4c754a91e3 100644 --- a/src/tir/schedule/primitive.h +++ 
b/src/tir/schedule/primitive.h @@ -470,6 +470,16 @@ TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int bu */ TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index, const String& storage_scope); +/*! + * \brief Set the data type of a buffer, where the buffer is specified by a block and a + * write-index + * \param self The state of the schedule + * \param block_sref The sref of the producer block of the buffer + * \param buffer_index The index of the buffer in block's write region + * \param dtype The data type to be set + */ +TVM_DLL void SetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_index, + const String& dtype); /*! * \brief Set the axis separator of a buffer, where the buffer is specified by a block and a read * or write index diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index 0912e36836e3..85f81486acdf 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -17,6 +17,7 @@ * under the License. */ #include "../utils.h" +#include namespace tvm { namespace tir { @@ -297,6 +298,93 @@ void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index, self->Replace(alloc_site_sref, new_block, block_reuse_map); } +/*! + * \brief A helper mutator which recursively mutates the old buffer's data type, inserts data type + * conversions, and collects the block sref reuse information for the following replacement. + */ +class DTypeMutator : private ReplaceBufferMutator { + public: + /*! + * \param allocate_site The block where `old_buffer` was allocated. 
+ * \param old_buffer The old buffer + * \param dtype The data type to be set + * \param block_sref_reuse The block sref reuse map to be updated + * \return The new block after the mutation + */ + static Block Mutate(const Block& allocate_site, const Buffer& old_buffer, const DataType& dtype, + Map* block_sref_reuse) { + Buffer new_buffer = WithDType(old_buffer, dtype); + DTypeMutator mutator(old_buffer, new_buffer, dtype, block_sref_reuse); + Stmt new_block = mutator.VisitStmt(allocate_site); + return Downcast(new_block); + } + + private: + DTypeMutator(const Buffer& old_buffer, Buffer new_buffer, const DataType& dtype, + Map* block_sref_reuse) + : ReplaceBufferMutator(old_buffer, std::move(new_buffer), block_sref_reuse), + src_dtype_(old_buffer->dtype), + tgt_dtype_(dtype) {} + + MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) final { + auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get()); + if (it != buffer_var_map_.end()) { + Buffer new_target_buffer = WithDType(match_buffer->buffer, it->second->dtype); + buffer_var_map_[match_buffer->buffer->data.get()] = new_target_buffer; + return MatchBufferRegion(new_target_buffer, + BufferRegion(it->second, match_buffer->source->region)); + } else { + return match_buffer; + } + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore node = Downcast(StmtExprMutator::VisitStmt_(op)); + auto it = buffer_var_map_.find(node->buffer->data.get()); + if (it != buffer_var_map_.end()) { + node.CopyOnWrite()->buffer = it->second; + node.CopyOnWrite()->value = Cast(tgt_dtype_, node->value); + } + return node; + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad node = Downcast(StmtExprMutator::VisitExpr_(op)); + auto it = buffer_var_map_.find(node->buffer->data.get()); + if (it != buffer_var_map_.end()) { + return Cast(src_dtype_, BufferLoad(it->second, node->indices)); + } + return node; + } + + DataType src_dtype_, tgt_dtype_; +}; + 
+void SetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_index, + const String& dtype) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + Buffer buffer = + GetNthAccessBuffer(self, GetRef(block), buffer_index, BufferIndexType::kWrite); + DataType target_dtype(runtime::String2DLDataType(dtype)); + + // Step 1. If `dtype` equals the original data type, just return. + if (buffer->dtype == target_dtype) { + return; + } + + // Step 2. Get the allocation site of the target buffer. + StmtSRef alloc_site_sref = + NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, block_sref, buffer); + const BlockNode* alloc_site = TVM_SREF_TO_BLOCK(alloc_site_sref); + + // Step 3. Recursively replace old buffer to a new buffer, where the new buffer has the given + // dtype, and insert data type conversions. + Map block_reuse_map; + Block new_block = + DTypeMutator::Mutate(GetRef(alloc_site), buffer, target_dtype, &block_reuse_map); + self->Replace(alloc_site_sref, new_block, block_reuse_map); +} + /******** InstructionKind Registration ********/ struct StorageAlignTraits : public UnpackedInstTraits { @@ -356,8 +444,36 @@ struct SetScopeTraits : public UnpackedInstTraits { friend struct ::tvm::tir::UnpackedInstTraits; }; +struct SetDTypeTraits : public UnpackedInstTraits { + static constexpr const char* kName = "SetDType"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, + String dtype) { + return sch->SetDType(block_rv, buffer_index->value, dtype); + } + + static String UnpackedAsPython(Array outputs, String block_rv, Integer buffer_index, + String dtype) { + PythonAPICall py("set_dtype"); + py.Input("block", block_rv); + py.Input("buffer_index", buffer_index); + py.Input("dtype", dtype); + return py.Str(); + } + + 
template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(StorageAlignTraits); TVM_REGISTER_INST_KIND_TRAITS(SetScopeTraits); +TVM_REGISTER_INST_KIND_TRAITS(SetDTypeTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index cb8b5a1d7787..ebe9b90d7604 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -211,6 +211,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign") .set_body_method(&ScheduleNode::StorageAlign); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope") .set_body_method(&ScheduleNode::SetScope); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetDType") + .set_body_method(&ScheduleNode::SetDType); /******** (FFI) Blockize & Tensorize ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize") .set_body_method(&ScheduleNode::Blockize); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index a5cb66a0cb44..71b658bc8e93 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -475,6 +475,16 @@ void TracedScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, /*outputs=*/{})); } +void TracedScheduleNode::SetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) { + ConcreteScheduleNode::SetDType(block_rv, buffer_index, dtype); + static const InstructionKind& kind = InstructionKind::Get("SetDType"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{Integer(buffer_index), dtype}, + /*outputs=*/{})); +} + /******** Schedule: Blockize & Tensorize ********/ BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) { diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 1fcba9806380..618238f4a2bf 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -100,6 +100,7 @@ class 
TracedScheduleNode : public ConcreteScheduleNode { void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, int offset) final; void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) final; + void SetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) final; /******** Schedule: Blockize & Tensorize ********/ BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) final; void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) final; diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index e91c5d142c04..baa7f44bbcf2 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -43,6 +43,16 @@ Buffer WithScope(const Buffer& buffer, const String& scope) { return Buffer(new_buffer); } +Buffer WithDType(const Buffer& buffer, const DataType& dtype) { + ObjectPtr new_buffer = make_object(*buffer.get()); + new_buffer->dtype = dtype; + const auto* ptr_type = TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode); + new_buffer->data = + Var(buffer->data->name_hint, PointerType(PrimType(dtype), ptr_type->storage_scope)); + new_buffer->name = buffer->name; + return Buffer(new_buffer); +} + Array ReplaceBuffer(Array regions, const Buffer& source, const Buffer& target) { regions.MutateByApply([&source, &target](BufferRegion region) -> BufferRegion { diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h index 3593d6b9a444..d2412436c7fb 100644 --- a/src/tir/schedule/transform.h +++ b/src/tir/schedule/transform.h @@ -53,6 +53,14 @@ Block WithAnnotation(const BlockNode* block, const String& attr_key, const Objec */ Buffer WithScope(const Buffer& buffer, const String& scope); +/*! + * \brief Create a new buffer by changing the data type. + * \param buffer The given buffer. + * \param dtype The target data type. + * \return The new buffer with target data type. 
+ */ +Buffer WithDType(const Buffer& buffer, const DataType& dtype); + /*! * \brief Replaces the buffer within the specific sequence of regions * \param regions The regions whose buffers are to be replaced @@ -131,9 +139,9 @@ class ReplaceBufferMutator : public StmtExprMutator { return node; } - Stmt VisitStmt_(const BufferStoreNode* op) final; + Stmt VisitStmt_(const BufferStoreNode* op) override; - PrimExpr VisitExpr_(const BufferLoadNode* op) final; + PrimExpr VisitExpr_(const BufferLoadNode* op) override; virtual MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer); diff --git a/tests/python/unittest/test_tir_schedule_set_dtype.py b/tests/python/unittest/test_tir_schedule_set_dtype.py new file mode 100644 index 000000000000..7fefc22fdea4 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_set_dtype.py @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-function-docstring,missing-module-docstring + +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg + +@T.prim_func +def element_wise(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")) -> None: + B = T.alloc_buffer((128, 128), dtype="float32") + + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + +@T.prim_func +def element_wise_set_dtype(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): + B = T.alloc_buffer((128, 128), "float16") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = T.cast(A[vi, vj] * 2.0, "float16") + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = T.cast(B[vi, vj], "float32") + 1.0 + +@T.prim_func +def element_wise_subregion_match(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")) -> None: + B = T.alloc_buffer((128, 128), dtype="float32") + + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B_subregion0 = T.match_buffer(B[vi, vj], [], offset_factor=1) + B_subregion0[()] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + B_subregion1 = T.match_buffer(B[vi, vj], [], offset_factor=1) + C[vi, vj] = B_subregion1[()] + 1.0 + + +@T.prim_func +def element_wise_subregion_match_set_dtype(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")) -> None: + B = 
T.alloc_buffer((128, 128), "float16") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B_subregion0 = T.match_buffer(B[vi, vj], (), "float16", offset_factor=1) + B_subregion0[()] = T.cast(A[vi, vj] * 2.0, "float16") + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) + B_subregion1 = T.match_buffer(B[vi, vj], (), "float16", offset_factor=1) + C[vi, vj] = T.cast(B_subregion1[()], "float32") + 1.0 + + +use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) + +def test_set_dtype(use_block_name): + func = element_wise + sch = tir.Schedule(func, debug_mask="all") + sch.set_dtype("B" if use_block_name else sch.get_block("B"), 0, "float16") + tvm.ir.assert_structural_equal(element_wise_set_dtype, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=func) + +def test_set_dtype_fail_on_output_buffer(use_block_name): + func = element_wise + sch = tir.Schedule(func, debug_mask='all') + with pytest.raises(tvm.tir.ScheduleError): + sch.set_dtype('C' if use_block_name else sch.get_block("C"), 0, "float16") + +def test_set_dtype_fail_on_index_out_of_bound(): + func = element_wise + sch = tir.Schedule(func, debug_mask='all') + with pytest.raises(tvm.tir.ScheduleError): + sch.set_dtype(sch.get_block("B"), 1, "float64") + with pytest.raises(tvm.tir.ScheduleError): + sch.set_dtype(sch.get_block("B"), -1, "float64") + +def test_set_dtype_subregion(): + func = element_wise_subregion_match + sch = tir.Schedule(func, debug_mask='all') + sch.set_dtype(sch.get_block("B"), 0, "float16") + tvm.ir.assert_structural_equal(element_wise_subregion_match_set_dtype, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=func) + + +if __name__ == "__main__": + tvm.testing.main() From 4fa8d18bc7d7ef561098fdf3258f22734e4e7d5b Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 16 Mar 2023 02:52:11 -0700 
Subject: [PATCH 2/7] add warning --- python/tvm/tir/schedule/schedule.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index b95d03d11142..84f9942516e0 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2457,6 +2457,7 @@ def after_set_dtype( Note ---- set_dtype requires the buffer to be an intermediate buffer defined via `alloc_buffer`. + This schedule primitive might influence the computation result because of type conversion. """ block = self._normalize_block_arg(block) _ffi_api.ScheduleSetDType( # type: ignore # pylint: disable=no-member From d3425571ccad83e9009652bd90a2c009b4bbd4cf Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 16 Mar 2023 03:15:56 -0700 Subject: [PATCH 3/7] lint --- tests/python/unittest/test_tir_schedule_set_dtype.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_schedule_set_dtype.py b/tests/python/unittest/test_tir_schedule_set_dtype.py index 7fefc22fdea4..1a8181b1d920 100644 --- a/tests/python/unittest/test_tir_schedule_set_dtype.py +++ b/tests/python/unittest/test_tir_schedule_set_dtype.py @@ -53,7 +53,7 @@ def element_wise_set_dtype(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[vi, vj]) T.writes(C[vi, vj]) - C[vi, vj] = T.cast(B[vi, vj], "float32") + 1.0 + C[vi, vj] = T.cast(B[vi, vj], "float32") + 1.0 @T.prim_func def element_wise_subregion_match(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")) -> None: From c2c0028fefd367e33296fbeff69c50e80372182c Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 16 Mar 2023 08:22:34 -0700 Subject: [PATCH 4/7] cpplint --- src/tir/schedule/primitive/block_annotate.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index 85f81486acdf..764e0894fc63 
100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -16,9 +16,10 @@ * specific language governing permissions and limitations * under the License. */ -#include "../utils.h" #include +#include "../utils.h" + namespace tvm { namespace tir { From f0f6c9bb77502cd0c7def6baa950d11d811cf5eb Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 19 Mar 2023 21:21:15 -0700 Subject: [PATCH 5/7] rename to unsafe_set_dtype --- python/tvm/tir/schedule/schedule.py | 10 ++++++---- tests/python/unittest/test_tir_schedule_set_dtype.py | 10 +++++----- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 84f9942516e0..889d08b94dc3 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2384,7 +2384,7 @@ def after_set_scope( Note ---- - Set_scope requires the buffer to be an intermediate buffer defined via `alloc_buffer`. + `set_scope` requires the buffer to be an intermediate buffer defined via `alloc_buffer`. """ block = self._normalize_block_arg(block) _ffi_api.ScheduleSetScope( # type: ignore # pylint: disable=no-member @@ -2392,10 +2392,13 @@ def after_set_scope( ) @type_checked - def set_dtype(self, block: Union[BlockRV, str], buffer_index: int, dtype: str) -> None: + def unsafe_set_dtype(self, block: Union[BlockRV, str], buffer_index: int, dtype: str) -> None: """Set the data type of a buffer, where the buffer is specified by the a block and write-index. + This schedule primitive is unsafe and may change the correctness of program because of + type conversion. + Parameters ---------- block : Union[BlockRV, str] @@ -2456,8 +2459,7 @@ def after_set_dtype( Note ---- - set_dtype requires the buffer to be an intermediate buffer defined via `alloc_buffer`. - This schedule primitive might influence the computation result because of type conversion. 
+ `set_dtype` requires the buffer to be an intermediate buffer defined via `alloc_buffer`. """ block = self._normalize_block_arg(block) _ffi_api.ScheduleSetDType( # type: ignore # pylint: disable=no-member diff --git a/tests/python/unittest/test_tir_schedule_set_dtype.py b/tests/python/unittest/test_tir_schedule_set_dtype.py index 1a8181b1d920..7f0900619b9b 100644 --- a/tests/python/unittest/test_tir_schedule_set_dtype.py +++ b/tests/python/unittest/test_tir_schedule_set_dtype.py @@ -95,7 +95,7 @@ def element_wise_subregion_match_set_dtype(A: T.Buffer((128, 128), "float32"), C def test_set_dtype(use_block_name): func = element_wise sch = tir.Schedule(func, debug_mask="all") - sch.set_dtype("B" if use_block_name else sch.get_block("B"), 0, "float16") + sch.unsafe_set_dtype("B" if use_block_name else sch.get_block("B"), 0, "float16") tvm.ir.assert_structural_equal(element_wise_set_dtype, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func) @@ -103,20 +103,20 @@ def test_set_dtype_fail_on_output_buffer(use_block_name): func = element_wise sch = tir.Schedule(func, debug_mask='all') with pytest.raises(tvm.tir.ScheduleError): - sch.set_dtype('C' if use_block_name else sch.get_block("C"), 0, "float16") + sch.unsafe_set_dtype('C' if use_block_name else sch.get_block("C"), 0, "float16") def test_set_dtype_fail_on_index_out_of_bound(): func = element_wise sch = tir.Schedule(func, debug_mask='all') with pytest.raises(tvm.tir.ScheduleError): - sch.set_dtype(sch.get_block("B"), 1, "float64") + sch.unsafe_set_dtype(sch.get_block("B"), 1, "float64") with pytest.raises(tvm.tir.ScheduleError): - sch.set_dtype(sch.get_block("B"), -1, "float64") + sch.unsafe_set_dtype(sch.get_block("B"), -1, "float64") def test_set_dtype_subregion(): func = element_wise_subregion_match sch = tir.Schedule(func, debug_mask='all') - sch.set_dtype(sch.get_block("B"), 0, "float16") + sch.unsafe_set_dtype(sch.get_block("B"), 0, "float16") 
tvm.ir.assert_structural_equal(element_wise_subregion_match_set_dtype, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func) From 1db3b2111299e8f1fbb0670570d6d6c596da823e Mon Sep 17 00:00:00 2001 From: Zihao Date: Tue, 21 Mar 2023 00:09:57 -0700 Subject: [PATCH 6/7] fix c++ api --- include/tvm/tir/schedule/schedule.h | 4 +++- python/tvm/tir/schedule/schedule.py | 4 ++-- src/tir/schedule/concrete_schedule.cc | 6 +++--- src/tir/schedule/concrete_schedule.h | 2 +- src/tir/schedule/primitive.h | 6 ++++-- src/tir/schedule/primitive/block_annotate.cc | 12 ++++++------ src/tir/schedule/schedule.cc | 4 ++-- src/tir/schedule/traced_schedule.cc | 7 ++++--- src/tir/schedule/traced_schedule.h | 2 +- 9 files changed, 26 insertions(+), 21 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index c58283b8ae11..215e330a4c6f 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -594,11 +594,13 @@ class ScheduleNode : public runtime::Object { /*! * \brief Set the data type of a buffer, where the buffer is specified by a block and a * write-index + * \note This schedule primitive is unsafe and may change correctness of program because of + * type conversion, please use with caution. * \param block_rv The producer block of the buffer * \param buffer_index the index of the buffer in block's write region * \param dtype The data type to be set */ - virtual void SetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) = 0; + virtual void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) = 0; /******** Schedule: Blockize & Tensorize ********/ /*! * \brief Convert the subtree rooted at a specific loop into a block. 
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 889d08b94dc3..9269acdd78ba 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2397,7 +2397,7 @@ def unsafe_set_dtype(self, block: Union[BlockRV, str], buffer_index: int, dtype: specified by the a block and write-index. This schedule primitive is unsafe and may change the correctness of program because of - type conversion. + type conversion, please use with caution. Parameters ---------- @@ -2462,7 +2462,7 @@ def after_set_dtype( `set_dtype` requires the buffer to be an intermediate buffer defined via `alloc_buffer`. """ block = self._normalize_block_arg(block) - _ffi_api.ScheduleSetDType( # type: ignore # pylint: disable=no-member + _ffi_api.ScheduleUnsafeSetDType( # type: ignore # pylint: disable=no-member self, block, buffer_index, dtype ) diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 14291b2403e0..330486b86ba2 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -701,10 +701,10 @@ void ConcreteScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, this->state_->DebugVerify(); } -void ConcreteScheduleNode::SetDType(const BlockRV& block_rv, int buffer_index, - const String& dtype) { +void ConcreteScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int buffer_index, + const String& dtype) { TVM_TIR_SCHEDULE_BEGIN(); - tir::SetDType(state_, this->GetSRef(block_rv), buffer_index, dtype); + tir::UnsafeSetDType(state_, this->GetSRef(block_rv), buffer_index, dtype); TVM_TIR_SCHEDULE_END("set-dtype", this->error_render_level_); this->state_->DebugVerify(); } diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index c0d5743eecdd..93f094304bf4 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -141,7 +141,7 @@ class ConcreteScheduleNode : public 
ScheduleNode { void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, int offset) override; void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override; - void SetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) override; + void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) override; /******** Schedule: Blockize & Tensorize ********/ BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) override; void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index db4c754a91e3..9c3540eb3d68 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -473,13 +473,15 @@ TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer /*! * \brief Set the data type of a buffer, where the buffer is specified by a block and a * write-index + * \note This schedule primitive is unsafe and may change correctness of program because of + * type conversion, please use with caution. * \param self The state of the schedule * \param block_sref The sref of the producer block of the buffer * \param buffer_index The index of the buffer in block's write region * \param dtype The data type to be set */ -TVM_DLL void SetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_index, - const String& dtype); +TVM_DLL void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_index, + const String& dtype); /*! 
* \brief Set the axis separator of a buffer, where the buffer is specified by a block and a read * or write index diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index 764e0894fc63..f1816d31330b 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -361,8 +361,8 @@ class DTypeMutator : private ReplaceBufferMutator { DataType src_dtype_, tgt_dtype_; }; -void SetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_index, - const String& dtype) { +void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_index, + const String& dtype) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); Buffer buffer = GetNthAccessBuffer(self, GetRef(block), buffer_index, BufferIndexType::kWrite); @@ -445,8 +445,8 @@ struct SetScopeTraits : public UnpackedInstTraits { friend struct ::tvm::tir::UnpackedInstTraits; }; -struct SetDTypeTraits : public UnpackedInstTraits { - static constexpr const char* kName = "SetDType"; +struct UnsafeSetDTypeTraits : public UnpackedInstTraits { + static constexpr const char* kName = "UnsafeSetDType"; static constexpr bool kIsPure = false; private: @@ -456,7 +456,7 @@ struct SetDTypeTraits : public UnpackedInstTraits { static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, String dtype) { - return sch->SetDType(block_rv, buffer_index->value, dtype); + return sch->UnsafeSetDType(block_rv, buffer_index->value, dtype); } static String UnpackedAsPython(Array outputs, String block_rv, Integer buffer_index, @@ -474,7 +474,7 @@ struct SetDTypeTraits : public UnpackedInstTraits { TVM_REGISTER_INST_KIND_TRAITS(StorageAlignTraits); TVM_REGISTER_INST_KIND_TRAITS(SetScopeTraits); -TVM_REGISTER_INST_KIND_TRAITS(SetDTypeTraits); +TVM_REGISTER_INST_KIND_TRAITS(UnsafeSetDTypeTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc 
b/src/tir/schedule/schedule.cc index ebe9b90d7604..a3d5346f7fe1 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -211,8 +211,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign") .set_body_method(&ScheduleNode::StorageAlign); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope") .set_body_method(&ScheduleNode::SetScope); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetDType") - .set_body_method(&ScheduleNode::SetDType); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeSetDType") + .set_body_method(&ScheduleNode::UnsafeSetDType); /******** (FFI) Blockize & Tensorize ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize") .set_body_method(&ScheduleNode::Blockize); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 71b658bc8e93..2b3a3e54b5d3 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -475,9 +475,10 @@ void TracedScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, /*outputs=*/{})); } -void TracedScheduleNode::SetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) { - ConcreteScheduleNode::SetDType(block_rv, buffer_index, dtype); - static const InstructionKind& kind = InstructionKind::Get("SetDType"); +void TracedScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int buffer_index, + const String& dtype) { + ConcreteScheduleNode::UnsafeSetDType(block_rv, buffer_index, dtype); + static const InstructionKind& kind = InstructionKind::Get("UnsafeSetDType"); trace_->Append(/*inst=*/Instruction( /*kind=*/kind, /*inputs=*/{block_rv}, diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 618238f4a2bf..e59dc564aadb 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -100,7 +100,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, int offset) 
final; void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) final; - void SetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) final; + void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) final; /******** Schedule: Blockize & Tensorize ********/ BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) final; void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) final; From 569fc4d4bbedff421adf17c4d78035514c6c07cb Mon Sep 17 00:00:00 2001 From: Zihao Date: Tue, 21 Mar 2023 09:29:00 -0700 Subject: [PATCH 7/7] fix --- src/tir/schedule/primitive/block_annotate.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index f1816d31330b..3f1789b3d6e6 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -461,7 +461,7 @@ struct UnsafeSetDTypeTraits : public UnpackedInstTraits { static String UnpackedAsPython(Array outputs, String block_rv, Integer buffer_index, String dtype) { - PythonAPICall py("set_dtype"); + PythonAPICall py("unsafe_set_dtype"); py.Input("block", block_rv); py.Input("buffer_index", buffer_index); py.Input("dtype", dtype);