diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h
index f6c15f9590df..9a2468714962 100644
--- a/include/tvm/ir/attrs.h
+++ b/include/tvm/ir/attrs.h
@@ -382,6 +382,47 @@ inline TFunc WithAttrs(TFunc input, Map<String, ObjectRef> attrs) {
   return input;
 }

+/*!
+ * \brief Copy the function or module, removing the specified
+ *        attribute.
+ *
+ * \param input The thing to annotate (BaseFunc or IRModule)
+ * \param attr_key The attribute key.
+ *
+ * \tparam TFunc The corresponding function or module type.
+ *
+ * \returns The new function or module with the attribute removed.
+ *
+ * \note This function performs copy on write optimization for func and module.
+ *       If we move a uniquely referenced func or module into WithoutAttr,
+ *       then no additional copy will be performed.
+ *
+ *       This is also why we make it a function instead of a member function
+ *       and why we pass by value in the first argument.
+ *
+ * \code
+ *
+ *  // Recommended way to trigger copy on write
+ *  func = WithoutAttr(std::move(func), "key1");
+ *  func = WithoutAttr(std::move(func), "key2");
+ *
+ * \endcode
+ */
+template <typename TFunc>
+inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) {
+  using TNode = typename TFunc::ContainerType;
+  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
+
+  if (input->attrs.defined()) {
+    TNode* node = input.CopyOnWrite();
+    node->attrs.CopyOnWrite()->dict.erase(attr_key);
+    if (node->attrs->dict.size() == 0) {
+      node->attrs = NullValue<DictAttrs>();
+    }
+  }
+  return input;
+}
+
 // Namespace containing detail implementations
 namespace detail {
 using runtime::TVMArgValue;
diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h
index 89074d83e1d6..99c86f0d58f7 100644
--- a/include/tvm/te/operation.h
+++ b/include/tvm/te/operation.h
@@ -268,6 +268,7 @@ class ComputeOp : public Operation {
                     Array<IterVar> axis, Array<PrimExpr> body);

   TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeOpNode);
 };

 /*!
diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h
index 17aedbcff308..8e637b43b52e 100644
--- a/include/tvm/te/schedule.h
+++ b/include/tvm/te/schedule.h
@@ -29,6 +29,7 @@
 #include
 #include
 #include
+#include <tvm/tir/index_map.h>

 #include
 #include
@@ -256,6 +257,41 @@ class Stage : public ObjectRef {
    * \return reference to self.
    */
   TVM_DLL Stage& rolling_buffer();  // NOLINT(*)
+  /*!
+   * \brief Defines a layout transformation to be applied to the buffer.
+   *
+   * The map from initial_indices to final_indices must be an
+   * invertible affine transformation.
+   *
+   * \param initial_indices An array of variables to represent a
+   *   value's location in the tensor, using the pre-transformation
+   *   layout.  These variables are used as binding occurrences to
+   *   represent the initial indices when applying the initial->final
+   *   mapping, and should not occur elsewhere in the
+   *   Schedule.  (i.e. Pass in newly constructed variables, not the
+   *   initial IterVar::var)
+   *
+   * \param final_indices An array of expressions, giving the
+   *   value's location in the tensor, using the post-transformation layout.
+   *   Expressions should be in terms of the variables given in
+   *   initial_indices.
+   *
+   * \param out_iter_vars An optional output location for the updated
+   *   loop iteration variables.
+   *
+   * \return reference to self
+   */
+  TVM_DLL Stage& transform_layout(const Array<Var>& initial_indices,
+                                  const Array<PrimExpr>& final_indices,
+                                  Array<IterVar>* out_iter_vars = nullptr);
+  /*! \brief Defines separators between groups of axes.
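+   *
+   * For example (a sketch only; `s` and `out` name an illustrative
+   * schedule and tensor, not part of this change):
+   *
+   * \code
+   * // Flatten a 4-d compute into a 2-d physical buffer by grouping
+   * // axes [0,1] and [2,3] together.
+   * s[out].set_axis_separators({IntImm(DataType::Int(32), 2)});
+   * \endcode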
+   *
+   * Used to define `BufferNode::axis_separators`, which has
+   * additional details.
+   *
+   * \param axis_separators A list of axis separators.
+   */
+  TVM_DLL Stage& set_axis_separators(const Array<IntImm>& axis_separators);
   /*!
    * \brief whether the stage has been scheduled.
    * \return whether the stage has been scheduled.
    */
@@ -466,9 +502,27 @@ class StageNode : public Object {
    * while origin_op remains fixed.
    */
   Operation origin_op;
-  /*! \brief All the nodes in the iter var */
+  /*! \brief All the nodes in the iter var
+   *
+   * Each element of all_iter_vars represents an iteration variable
+   * that may appear within this stage's computation.  Any element
+   * of `all_iter_vars` that is in `leaf_iter_vars` represents a
+   * variable that is directly defined and usable within the stage's
+   * computation.  All other elements of `all_iter_vars` represent
+   * variables whose value must be computed from the variables in
+   * `leaf_iter_vars`.  (e.g. Suppose index k has been split by
+   * ``ko, ki = s.split(k, factor=4)``.  ko and ki will appear in
+   * `leaf_iter_vars`, while k will not, and must be computed as
+   * `4*ko + ki`.)
+   */
   Array<IterVar> all_iter_vars;
-  /*! \brief The current active leaf iter vars in the stage. */
+  /*! \brief The current active leaf iter vars in the stage.
+   *
+   * Each element of leaf_iter_vars will either be replaced with the
+   * bound index (e.g. threadIdx.x), or will be expanded into a loop
+   * over the variable's extent.  `leaf_iter_vars` is a subset of
+   * `all_iter_vars`.
+   */
   Array<IterVar> leaf_iter_vars;
   /*!
    * \brief Specify threads to be launched at the stage.
@@ -500,6 +554,14 @@ class StageNode : public Object {
   bool double_buffer{false};
   /*! \brief Whether apply rolling buffer optimization to this stage */
   bool rolling_buffer{false};
+  /*! \brief Layout transformations to be applied onto the stage's tensors. */
+  Array<IndexMap> layout_transforms;
+  /*! \brief List of axes after which to divide physical axes.
+   *
+   * Used to populate `BufferNode::axis_separators`, which has
+   * additional details.
+   */
+  Array<IntImm> axis_separators;
   /*!
    * \brief The parent group of the current stage.
    *  The stage cannot be assigned to stages outside the group.
@@ -522,6 +584,8 @@ class StageNode : public Object {
     v->Visit("scope", &scope);
     v->Visit("is_output", &is_output);
     v->Visit("double_buffer", &double_buffer);
+    v->Visit("layout_transforms", &layout_transforms);
+    v->Visit("axis_separators", &axis_separators);
     v->Visit("group", &group);
     v->Visit("num_child_stages", &num_child_stages);
   }
@@ -771,6 +835,61 @@ class Singleton : public IterVarRelation {
   TVM_DEFINE_OBJECT_REF_METHODS(Singleton, IterVarRelation, SingletonNode);
 };

+/*!
+ * \brief Transform iterator according to some arbitrary expression.
+ */
+class TransformNode : public IterVarRelationNode {
+ public:
+  /*! \brief The loop variables that were replaced by the transformation.
+   *
+   * Prior to applying a layout transformation, these represent the
+   * loops to iterate over a tensor as it is being computed, following
+   * a row-major traversal of the tensor's original shape in the
+   * compute definition.
+   */
+  Array<IterVar> original_variables;
+
+  /*! \brief The variables generated by the transformation.
+   *
+   * After applying a layout transformation, these represent the
+   * loops to iterate over a tensor as it is being computed, following
+   * a row-major traversal of the transformed shape of the tensor.
+   */
+  Array<IterVar> transformed_variables;
+
+  /*! \brief Map from the original variables to the transformed variables.
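+   *
+   * For example, the forward transformation `lambda i: [i // 4, i % 4]`
+   * maps an original loop variable `i` of extent 16 to transformed
+   * variables `(io, ii)` of extents `(4, 4)`.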
+   *
+   * Used to determine iterator ranges over the transformed variables.
+   */
+  IndexMap forward_transformation;
+
+  /*! \brief Map from transformed variables to the original variables
+   *
+   * Used to rewrite expressions containing the original loop iterators
+   * in terms of the transformed loop iterators.
+   */
+  IndexMap inverse_transformation;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("original_variables", &original_variables);
+    v->Visit("transformed_variables", &transformed_variables);
+    v->Visit("forward_transformation", &forward_transformation);
+    v->Visit("inverse_transformation", &inverse_transformation);
+  }
+
+  static constexpr const char* _type_key = "Transform";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TransformNode, IterVarRelationNode);
+};
+
+class Transform : public IterVarRelation {
+ public:
+  TVM_DLL explicit Transform(Array<IterVar> original_variables,
+                             Array<IterVar> transformed_variables,
+                             IndexMap forward_transformation,
+                             IndexMap inverse_transformation);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(Transform, IterVarRelation, TransformNode);
+};
+
 /*! \brief Container for specialization conditions. */
 class SpecializedConditionNode : public Object {
  public:
diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index 69453e23ac1a..aef82ae368d0 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -55,8 +55,22 @@ class BufferNode : public Object {
   Var data;
   /*! \brief data type in the content of the tensor */
   DataType dtype;
-  /*! \brief The shape of the buffer */
+  /*! \brief The shape of the buffer prior to flattening
+   *
+   * This contains the shape as it is accessed by
+   * BufferLoad/BufferStore nodes, and used by the low-level code
+   * generators.
+   */
   Array<PrimExpr> shape;
+  /*!
+   * \brief Separators between input axes when generating flattened output axes
+   *
+   * For buffers representing flat 1-d memory (e.g. any buffer in
+   * RAM), this should be an empty array.  For buffers representing
+   * non-flat memory, each entry in axis_separators should be the
+   * first input axis that is part of a new flattened axis.
+   */
+  Array<IntImm> axis_separators;
   /*!
    * \brief The strides of each dimension
    *  This can be an empty array, indicating array is contiguous
@@ -89,6 +103,7 @@ class BufferNode : public Object {
     v->Visit("dtype", &dtype);
     v->Visit("shape", &shape);
     v->Visit("strides", &strides);
+    v->Visit("axis_separators", &axis_separators);
     v->Visit("elem_offset", &elem_offset);
     v->Visit("name", &name);
     v->Visit("data_alignment", &data_alignment);
@@ -98,10 +113,11 @@ class BufferNode : public Object {
   }

   bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const {
-    // Use DefEqual as buffer can define variables
-    // in its semantics, skip name as name is not important.
+    // Use DefEqual as buffer can define variables in its semantics,
+    // skip name as name is not important.
     return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) &&
            equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) &&
+           equal.DefEqual(axis_separators, other->axis_separators) &&
            equal.DefEqual(elem_offset, other->elem_offset) &&
            equal(data_alignment, other->data_alignment) && equal(buffer_type, other->buffer_type);
   }
@@ -112,6 +128,7 @@ class BufferNode : public Object {
     hash_reduce.DefHash(shape);
     hash_reduce.DefHash(strides);
     hash_reduce.DefHash(elem_offset);
+    hash_reduce.DefHash(axis_separators);
     hash_reduce(data_alignment);
     hash_reduce(buffer_type);
   }
@@ -127,7 +144,7 @@ class BufferNode : public Object {
    * without adjusting for number of lanes.  (e.g. The number of
   * float16x4 elements in a buffer of type float16x4.)
   */
-  PrimExpr ElemOffset(Array<PrimExpr> index) const;
+  Array<PrimExpr> ElemOffset(Array<PrimExpr> index) const;

   static constexpr const char* _type_key = "tir.Buffer";
   static constexpr const bool _type_has_method_sequal_reduce = true;
@@ -146,7 +163,7 @@ class Buffer : public ObjectRef {
   // A default value will be picked.
   TVM_DLL Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
                  PrimExpr elem_offset, String name, int data_alignment, int offset_factor,
-                 BufferType buffer_type, Span span = Span());
+                 BufferType buffer_type, Array<IntImm> axis_separators = {}, Span span = Span());

   /*!
    * \brief Return a new buffer that is equivalent with current one
@@ -186,6 +203,19 @@ class Buffer : public ObjectRef {
    */
   TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const;

+  /*!
+   * \brief Get a flattened version of the buffer
+   */
+  Buffer GetFlattenedBuffer() const;
+
+  /*! \brief Determine the offset in the buffer of the given index.
+   *
+   * Returns the buffer offset, in number of elements of type dtype,
+   * without adjusting for number of lanes.  (e.g. The number of
+   * float16x4 elements in a buffer of type float16x4.)
+   */
+  Array<PrimExpr> OffsetOf(Array<PrimExpr> index) const;
+
   /*!
    * \brief Return the storage scope associated with this buffer.
    */
@@ -201,12 +231,14 @@ class Buffer : public ObjectRef {
  * \param dtype The content data type.
  * \param name The name of the buffer
  * \param storage_scope The storage scope associated with this buffer
+ * \param axis_separators Divisions defining the groups of axes that will be flattened together.
  * \param span The location of this object in the source code.
  * \return The created buffer.
  * \sa Buffer for complete constructor.
  */
 TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
-                           String name = "buffer", String storage_scope = "", Span span = Span());
+                           String name = "buffer", String storage_scope = "",
+                           Array<IntImm> axis_separators = {}, Span span = Span());

 /*!
  * \brief Base node for data producers.
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index d8a5ea67d844..f7e1cfbc3e6d 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -105,10 +105,15 @@ TVM_DLL const Op& large_uint_imm();
 TVM_DLL const Op& q_multiply_shift();

 /*!
- * \brief See pesudo code
+ * \brief Returns the address of an element in the buffer (see pseudocode below).
+ *
+ * The number of indices should match the dimensionality of the buffer
+ * being accessed.  If this operation occurs after buffer flattening,
+ * the number of indices must be supported by the target (i.e. N>1
+ * only on targets that support non-flat memory buffers).
  *
- *  Handle address_of(Load *op) {
- *     return &op->buffer_var[index];
+ *  Handle address_of(BufferLoad *op) {
+ *     return &op->buffer_var[op->indices[0], op->indices[1], ..., op->indices[N-1]];
  *  }
  */
 TVM_DLL const Op& address_of();
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index f6741112f269..674ff0b7f43c 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -630,6 +630,22 @@ class BufferLoadNode : public PrimExprNode {

   static constexpr const char* _type_key = "tir.BufferLoad";
   TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode);
+
+ private:
+  /*! \brief Set the dtype based on the buffer/indices
+   *
+   * Usually, the BufferLoad's dtype will be the same dtype as the
+   * buffer.  This may have a different number of lanes than the
+   * buffer's dtype if index values have more than 1 lane.
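+   * (e.g. indexing a `float32` buffer with an `int32x4` index vector
+   * yields a vectorized BufferLoad of dtype `float32x4`.)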
+   *
+   * This function should only be called during construction and after
+   * CopyOnWrite.  Friend class used here to restrict usage.
+   */
+  void LegalizeDType();
+  friend class BufferLoad;
+  friend class CustomDatatypesLowerer;
+  friend class VectorTypeRewriter;
+  friend class Vectorizer;
 };

 /*!
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index 2b3c4b5fe003..dc7014cc8aab 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -91,11 +91,30 @@ class PrimFuncNode : public BaseFuncNode {
    */
   Map<tir::Var, Buffer> buffer_map;

+  /*! \brief The buffer map prior to flattening.
+   *
+   * This contains the buffers as they exist prior to flattening, and
+   * is used for validating an input tensor passed into the packed
+   * API.  Any buffer that is present in `buffer_map` but not present
+   * in `preflattened_buffer_map` is assumed to be the same before
+   * and after flattening (e.g. a 1-d tensor that is backed by 1-d
+   * flat memory).
+   *
+   * TODO(Lunderberg): Remove preflattened_buffer_map, and instead
+   * declare each flattened buffer as aliasing the original tensor
+   * shape.  This should include improving the StmtExprMutator to
+   * provide easier interactions with Buffer objects, so that the
+   * bookkeeping of relationships between buffers doesn't need to be
+   * repeated across several transforms.
+   */
+  Map<tir::Var, Buffer> preflattened_buffer_map;
+
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("params", &params);
     v->Visit("body", &body);
     v->Visit("ret_type", &ret_type);
     v->Visit("buffer_map", &buffer_map);
+    v->Visit("preflattened_buffer_map", &preflattened_buffer_map);
     v->Visit("attrs", &attrs);
     v->Visit("span", &span);
     v->Visit("_checked_type_", &checked_type_);
@@ -104,6 +123,7 @@ class PrimFuncNode : public BaseFuncNode {
   bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const {
     // visit params and buffer_map first as they contains defs.
     return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) &&
+           equal(preflattened_buffer_map, other->preflattened_buffer_map) &&
            equal(ret_type, other->ret_type) && equal(body, other->body) &&
            equal(attrs, other->attrs);
   }
@@ -111,6 +131,7 @@ class PrimFuncNode : public BaseFuncNode {
   void SHashReduce(SHashReducer hash_reduce) const {
     hash_reduce.DefHash(params);
     hash_reduce(buffer_map);
+    hash_reduce(preflattened_buffer_map);
     hash_reduce(ret_type);
     hash_reduce(body);
     hash_reduce(attrs);
@@ -136,16 +157,33 @@ class PrimFunc : public BaseFunc {
  public:
   /*!
    * \brief Constructor
+   *
    * \param params The parameters of the function.
+   *
    * \param body The body of the function.
+   *
    * \param ret_type The return type of the function.
+   *
    * \param buffer_map The buffer map for parameter buffer unpacking.
+   *        This contains buffer objects as they appear in the body of the
+   *        PrimFunc.  (e.g. a buffer of shape ``[1024]`` originally
+   *        generated as a tensor of shape ``[32, 32]``)
+   *
+   * \param preflattened_buffer_map The buffer map for
+   *        parameter buffer unpacking.  This contains buffer
+   *        objects as the callee expects them to be passed in.
+   *        (e.g. a buffer of shape ``[32, 32]`` originally
+   *        generated as a tensor of shape ``[32, 32]``)
+   *
    * \param attrs Additional function attributes.
+   *
    * \param span The location of this object in the source code.
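+   *
+   * \code
+   * // A sketch only; `params`, `body`, and the two buffer maps are
+   * // assumed to have been built elsewhere.
+   * PrimFunc func(params, body, VoidType(), buffer_map,
+   *               preflattened_buffer_map);
+   * \endcode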
   */
-  TVM_DLL PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
-                   Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(),
-                   DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
+  TVM_DLL PrimFunc(
+      Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
+      Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(),
+      Optional<Map<tir::Var, Buffer>> preflattened_buffer_map = Optional<Map<tir::Var, Buffer>>(),
+      DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());

   TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h
new file mode 100644
index 000000000000..237111306c2a
--- /dev/null
+++ b/include/tvm/tir/index_map.h
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/tir/index_map.h
+ * \brief Defines a remapping of buffer indices
+ *
+ * For use with tvm::tir::Buffer.
+ */
+#ifndef TVM_TIR_INDEX_MAP_H_
+#define TVM_TIR_INDEX_MAP_H_
+
+#include
+#include
+#include
+#include
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Defines a mapping between two representations of indices
+ * into a buffer.
+ *
+ * This is primarily used for layout transformations of Buffer
+ * objects.
+ */
+class IndexMapNode : public Object {
+ public:
+  /*! \brief Variables representing the indices prior to remapping.
+   *
+   * If initial_indices is empty, then final_indices should also be
+   * empty, and no mapping is applied.
+   */
+  Array<Var> initial_indices;
+
+  /*!
+   * \brief Expressions defining the indices after remapping.
+   *
+   * These expressions should only be in terms of the initial_indices,
+   * and must be expressible as an IterSumExpr.  The mapping from
+   * initial_indices to final_indices must be injective.
+   *
+   * If final_indices is empty, then initial_indices should also be
+   * empty, and the map is an identity function.
+   */
+  Array<PrimExpr> final_indices;
+
+  /*!
+   * \brief Default constructor
+   *
+   * Defines the mapping as an identity function, with initial_indices
+   * equal to the final indices.
+   */
+  IndexMapNode() {}
+
+  /*!
+   * \brief Map indices to the output space
+   *
+   * \param indices The indices in the input space.  Should contain
+   * one value for each variable in `initial_indices`.
+   *
+   * \returns The indices in the output space.  Contains one value for
+   * each expression in `final_indices`.
+   */
+  Array<PrimExpr> MapIndices(const Array<PrimExpr>& indices) const;
+
+  /*! \brief Map a memory range to the output space
+   *
+   * If contiguous memory locations in the input space are not
+   * necessarily contiguous in the output space (e.g. `lambda i:
+   * [8*(i%8) + (i//8)]`), then this will return the smallest range
+   * such that all valid indices are contained within the given range.
+   *
+   * \param ranges The ranges in the input space.  Should contain one
+   * value for each variable in `initial_indices`.
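+   *
+   * For example, under the map `lambda i: [8*(i%8) + (i//8)]` above,
+   * the input range `[0, 4)` maps to the scattered indices
+   * {0, 8, 16, 24}, so the covering range `[0, 25)` would be returned.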
+   *
+   * \returns The ranges in the output space.  Contains one value for
+   * each expression in `final_indices`.
+   */
+  Array<Range> MapRanges(const Array<Range>& ranges) const;
+
+  /*! \brief Map a buffer shape to the output space
+   *
+   * \param shape The buffer shape in the input space.  Should contain
+   * one value for each variable in `initial_indices`.
+   *
+   * \returns The buffer shape in the output space.  Contains one
+   * value for each expression in `final_indices`.
+   */
+  Array<PrimExpr> MapShape(const Array<PrimExpr>& shape) const;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("initial_indices", &initial_indices);
+    v->Visit("final_indices", &final_indices);
+  }
+
+  TVM_DECLARE_FINAL_OBJECT_INFO(IndexMapNode, Object);
+};
+
+class IndexMap : public ObjectRef {
+ public:
+  IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices);
+
+  /*! \brief Generate the inverse mapping.
+   *
+   * The range of the input indices is required in order to ensure
+   * that the transformation is bijective over the input domain.
+   *
+   * TODO(Lunderberg): Look into allowing non-bijective
+   * transformations.  If injective, the inverse mapping could still
+   * be generated with some predicate.  If non-injective, could
+   * simplify the implementation of other optimizations (e.g. double
+   * buffering as a map `lambda *indices: [buffer_loop%2, *indices]`).
+   */
+  IndexMap Inverse(Array<Range> initial_ranges) const;
+
+  TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode);
+};
+
+}  // namespace tir
+}  // namespace tvm
+
+#endif  // TVM_TIR_INDEX_MAP_H_
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 972f78171569..9ccab50eced2 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -388,6 +388,7 @@ class BufferRealize : public Stmt {
                          Span span = Span());

   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRealizeNode);
 };

 /*!
@@ -585,6 +586,7 @@ class Allocate : public Stmt {
                    Span span = Span());

   TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateNode);
 };

 /*!
@@ -1372,6 +1374,21 @@ constexpr const char* pragma_tensor_core = "pragma_tensor_core";
  * run prefetch of Tensor on the current loop scope
  */
 constexpr const char* prefetch_scope = "prefetch_scope";
+/*!
+ * \brief Marks the layout transforms to be used for a tensor.
+ *
+ * Only applies to a DataProducer, as it should be made part of the
+ * PrimFunc attributes for TIR.
+ */
+constexpr const char* layout_transforms = "layout_transforms";
+/*!
+ * \brief Marks the physical axis separators
+ *
+ * Only applies to a DataProducer, as it should be made part of the
+ * Buffer definition in a PrimFunc.  See `BufferNode::axis_separators`
+ * for more details.
+ */
+constexpr const char* axis_separators = "axis_separators";
 /*!
  * \brief Marks production of double buffer data
  */
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 04027f8974fe..ef36c015957a 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -1566,7 +1566,11 @@ inline Array<Tensor> meshgrid(const Array<Tensor>& inputs, const std::string& in
         out_shape,
         [&](const Array<Var>& indices) {
           const int src_index = (cartesian_indexing && i < 2) ?
1 - i : i; - Array real_indices = {indices[src_index]}; + auto ndim = inputs[i]->GetShape().size(); + Array real_indices = {}; + if (ndim > 0) { + real_indices = {indices[src_index]}; + } return inputs[i](real_indices); }, name, tag)); @@ -1815,7 +1819,7 @@ inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array& indices) { PrimExpr ret = default_value; if (0 == rank_sparse_indices) { - ret = if_then_else(indices[0] == sparse_indices[0], sparse_values[0], ret); + ret = if_then_else(indices[0] == sparse_indices(), sparse_values(), ret); } else if (1 == rank_sparse_indices) { for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) { ret = if_then_else(indices[0] == sparse_indices[j], sparse_values[j], ret); diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py index 53b46aeafbf5..11e472070a6b 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py @@ -77,12 +77,12 @@ def get_binary_elementwise_params( _, _, _, _, _, inner = get_outer_loops(body, "NHWC") op = ignore_cast(inner.value) - input_pointer = ignore_cast(op.a).buffer_var - input_pointer1 = ignore_cast(op.b).buffer_var + input_pointer = ignore_cast(op.a).buffer.data + input_pointer1 = ignore_cast(op.b).buffer.data if reversed_operands: input_pointer, input_pointer1 = input_pointer1, input_pointer - output_pointer = inner.buffer_var + output_pointer = inner.buffer.data # Get feature map info serial_ifm, _ = get_ifm_params(input_pointer, producers) serial_ifm2, _ = get_ifm_params(input_pointer1, producers) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py index 50c27cc01689..bdca6a874ca5 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py @@ -59,8 +59,8 @@ def get_conv2d_params(stmt, producers, consumers): loads = get_loads(rc.body) # stores = [output] stores = get_stores(rc.body) - input_pointer = loads[1].buffer_var - output_pointer = stores[0].buffer_var + input_pointer = loads[1].buffer.data + output_pointer = stores[0].buffer.data # Get feature map info serial_ifm, serial_padding = get_ifm_params(input_pointer, producers) serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers) @@ -75,16 +75,16 @@ def get_conv2d_params(stmt, producers, consumers): ) # Get scale_bias info scale_bias_load = loads[3] - scale_bias_base = get_base_address(scale_bias_load.index) + scale_bias_base = [get_base_address(index) for index in scale_bias_load.indices] serial_scale_bias = SerialAddressRange( - address=tvm.tir.Load("uint8", scale_bias_load.buffer_var, scale_bias_base), + address=tvm.tir.BufferLoad(scale_bias_load.buffer, scale_bias_base), length=SCALE_BIAS_LENGTH * serial_ofm[3], ) # Get weight info weight_load = loads[2] - weight_base = get_base_address(weight_load.index) + weight_base = [get_base_address(index) for index in weight_load.indices] serial_weight = SerialAddressRange( - address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base), + address=tvm.tir.BufferLoad(weight_load.buffer, weight_base), length=serial_ofm[3] * serial_kernel[0] * serial_kernel[1] * rc.extent, ) # Get activation info diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py 
index b1a4ebd82a88..b39ec36e4231 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py @@ -68,8 +68,8 @@ def get_depthwise_conv2d_params( loads = get_loads(rw.body) # stores = [output] stores = get_stores(rw.body) - input_pointer = loads[1].buffer_var - output_pointer = stores[0].buffer_var + input_pointer = loads[1].buffer.data + output_pointer = stores[0].buffer.data # Get feature map info serial_ifm, serial_padding = get_ifm_params(input_pointer, producers) serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers) @@ -84,16 +84,16 @@ def get_depthwise_conv2d_params( ) # Get scale_bias info scale_bias_load = loads[3] - scale_bias_base = get_base_address(scale_bias_load.index) + scale_bias_base = [get_base_address(index) for index in scale_bias_load.indices] serial_scale_bias = SerialAddressRange( - address=tvm.tir.Load("uint8", scale_bias_load.buffer_var, scale_bias_base), + address=tvm.tir.BufferLoad(scale_bias_load.buffer, scale_bias_base), length=SCALE_BIAS_LENGTH * serial_ofm[3], ) # Get weight info weight_load = loads[2] - weight_base = get_base_address(weight_load.index) + weight_base = [get_base_address(index) for index in weight_load.indices] serial_weight = SerialAddressRange( - address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base), + address=tvm.tir.BufferLoad(weight_load.buffer, weight_base), length=serial_ofm[3] * serial_kernel[0] * serial_kernel[1], ) # Get activation info diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py index 9f82d7478265..aa4c09f24d7c 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py @@ -41,12 +41,12 @@ def get_pad_params(stmt): """ _, body = get_op_attrs(stmt) n, h, w, c, _, inner = get_outer_loops(body, "NHWC") - output_pointer = inner.buffer_var + output_pointer = inner.buffer.data pad = SerialPadding(top=0, left=0, bottom=0, right=0) if isinstance(inner.value, tvm.tir.Call): - input_pointer = inner.value.args[1].buffer_var + input_pointer = inner.value.args[1].buffer.data else: - input_pointer = inner.value.buffer_var + input_pointer = inner.value.buffer.data return pad, input_pointer, output_pointer padded_shape = [n.extent, h.extent, w.extent, c.extent] @@ -94,10 +94,10 @@ def get_upscale_params(stmt): _, body = get_op_attrs(stmt) _, _, _, _, _, inner = get_outer_loops(body, "NHWC") if isinstance(inner.value, tvm.tir.Call): - input_pointer = inner.value.args[1].buffer_var + input_pointer = inner.value.args[1].buffer.data else: - input_pointer = inner.value.buffer_var - output_pointer = inner.buffer_var + input_pointer = inner.value.buffer.data + output_pointer = inner.buffer.data return (input_pointer, output_pointer) @@ -126,11 +126,11 @@ def get_convert_to_nhwc_params(stmt): # compute that is deemed uneccesary isn't removed by TVM. 
if attrs["layout"] == "NHCWB16": inner = inner.body - input_pointer = inner.value.b.buffer_var + input_pointer = inner.value.b.buffer.data else: - input_pointer = inner.value.buffer_var + input_pointer = inner.value.buffer.data - output_pointer = inner.buffer_var + output_pointer = inner.buffer.data return c.extent, input_pointer, output_pointer @@ -154,13 +154,13 @@ def get_convert_to_nhcwb16_params(stmt): """ attrs, body = get_op_attrs(stmt) _, _, _, c, b, inner = get_outer_loops(body, attrs["layout"]) - output_pointer = inner.buffer_var + output_pointer = inner.buffer.data if isinstance(inner.value, tvm.tir.Call): cond = inner.value.args[0] out_channels = cond.b.value - input_pointer = inner.value.args[1].buffer_var + input_pointer = inner.value.args[1].buffer.data else: - input_pointer = inner.value.buffer_var + input_pointer = inner.value.buffer.data out_channels = c.extent * b.extent if attrs["layout"] == "NHCWB16" else c.extent return out_channels, input_pointer, output_pointer @@ -186,12 +186,17 @@ def get_read_params(stmt): """ attrs, body = get_op_attrs(stmt) _, h, w, c, _, inner = get_outer_loops(body, attrs["layout"]) - input_pointer = inner.value.buffer_var - output_pointer = inner.buffer_var + input_pointer = inner.value.buffer.data + output_pointer = inner.buffer.data + + # Needed for stride calculation, can replace with + # inner.value.buffer.strides in future. + assert len(inner.value.indices) == 1, "Ethos-U DMA expects flattened buffers" stride_vars = [h.loop_var, w.loop_var, c.loop_var] - strides = get_strides(inner.value.index, stride_vars) - base_address = get_base_address(inner.value.index) - data_type = inner.buffer_var.type_annotation.element_type.dtype + strides = get_strides(inner.value.indices[0], stride_vars) + + base_address = [get_base_address(index) for index in inner.value.indices] + data_type = inner.buffer.data.type_annotation.element_type.dtype return ( SerialFeatureMap( data_type=data_type, @@ -201,7 +206,7 @@ def get_read_params(stmt): tile_height_0=h.extent, tile_height_1=0, tile_width_0=w.extent, - tile_address_0=tvm.tir.Load(data_type, inner.value.buffer_var, base_address), + tile_address_0=tvm.tir.BufferLoad(inner.value.buffer, base_address), tile_address_1=0, tile_address_2=0, tile_address_3=0, @@ -237,12 +242,17 @@ def get_write_params(stmt): """ attrs, body = get_op_attrs(stmt) _, h, w, c, _, inner = get_outer_loops(body, attrs["layout"]) - input_pointer = inner.value.buffer_var - output_pointer = inner.buffer_var + input_pointer = inner.value.buffer.data + output_pointer = inner.buffer.data + + # Needed for stride calculation, can replace with + # inner.value.buffer.strides in future. 
+ assert len(inner.indices) == 1, "Ethos-U DMA expects flattened buffers" stride_vars = [h.loop_var, w.loop_var, c.loop_var] - strides = get_strides(inner.index, stride_vars) - base_address = get_base_address(inner.index) - data_type = inner.buffer_var.type_annotation.element_type.dtype + strides = get_strides(inner.indices[0], stride_vars) + + base_address = [get_base_address(index) for index in inner.indices] + data_type = inner.buffer.data.type_annotation.element_type.dtype return ( SerialFeatureMap( data_type=data_type, @@ -252,7 +262,7 @@ def get_write_params(stmt): tile_height_0=h.extent, tile_height_1=0, tile_width_0=w.extent, - tile_address_0=tvm.tir.Load(data_type, inner.buffer_var, base_address), + tile_address_0=tvm.tir.BufferLoad(inner.buffer, base_address), tile_address_1=0, tile_address_2=0, tile_address_3=0, diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/identity.py b/python/tvm/relay/backend/contrib/ethosu/tir/identity.py index 6dccb5a15c97..40686ac2336f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/identity.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/identity.py @@ -59,12 +59,14 @@ def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatur fm_inner = inner.value if fm_type == "ifm" else inner + # Needed for stride calculation, can replace with + # inner.value.buffer.strides in future. + assert len(fm_inner.indices) == 1, "Ethos-U passes expect flattened buffers" stride_vars = [l.loop_var for l in loops] - strides = get_strides(fm_inner.index, stride_vars) + strides = get_strides(fm_inner.indices[0], stride_vars) - base_address = get_base_address(fm_inner.index) - data_type = inner.buffer_var.type_annotation.element_type.dtype - pointer = fm_inner.buffer_var + base_address = [get_base_address(index) for index in fm_inner.indices] + data_type = inner.buffer.data.type_annotation.element_type.dtype serial_feature_map = SerialFeatureMap( data_type=data_type, @@ -74,7 +76,7 @@ def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatur tile_height_0=loops[0].extent, tile_height_1=0, tile_width_0=loops[1].extent if len(loops) > 1 else 1, - tile_address_0=tvm.tir.Load(data_type, pointer, base_address), + tile_address_0=tvm.tir.BufferLoad(fm_inner.buffer, base_address), tile_address_1=0, tile_address_2=0, tile_address_3=0, @@ -86,7 +88,7 @@ def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatur stride_c=strides[2] if len(strides) > 2 else 1, ) - output_pointer = inner.buffer_var + output_pointer = inner.buffer.data return serial_feature_map, output_pointer @@ -130,8 +132,8 @@ def get_identity_params( # loads = [input, LUT, LUT] loads = get_loads(stmt) - input_pointer = loads[0].buffer_var - output_pointer = stmt.buffer_var + input_pointer = loads[0].buffer.data + output_pointer = stmt.buffer.data read = producers[input_pointer] write = consumers[output_pointer] diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index c2fff8abb9b0..5f0b9fe3b690 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -28,7 +28,7 @@ from .identity import get_identity_params from .unary_elementwise import get_unary_elementwise_params from .transform import get_copy_params -from .utils import get_weights_pointer, get_scale_bias_pointer +from .utils import get_weights_buffer, get_scale_bias_buffer def RemoveZeroStores(): @@ -82,8 +82,8 @@ def 
_resolve_pointers(stmt): loads = [] def _get_loads(stmt): - if isinstance(stmt, tvm.tir.Load): - loads.append(stmt.buffer_var) + if isinstance(stmt, tvm.tir.BufferLoad): + loads.append(stmt.buffer.data) if isinstance(stmt, tvm.tir.Allocate): pointer_to_extents[stmt.buffer_var] = stmt.extents @@ -94,8 +94,8 @@ def _get_loads(stmt): elif isinstance(stmt, tvm.tir.AttrStmt): if stmt.attr_key == "pragma_op": tvm.tir.stmt_functor.post_order_visit(stmt, _get_loads) - for load_buffer in loads: - pointer_to_consumer[load_buffer] = stmt + for load_pointer in loads: + pointer_to_consumer[load_pointer] = stmt def _replace_operator(stmt): """Replace operators with call_externs, having derived the parameters @@ -232,21 +232,26 @@ def DivideConstants(const_dict): def _visit(stmt): new_args = [] for i, arg in enumerate(stmt.args): - if isinstance(arg, tvm.tir.expr.Load): + if isinstance(arg, tvm.tir.expr.BufferLoad): # If we're trying to load a buffer that maps to a constant - if arg.buffer_var in buffer_to_const: - const = buffer_to_const[arg.buffer_var] - offset = int(arg.index) + if arg.buffer.data in buffer_to_const: + const = buffer_to_const[arg.buffer.data] + + assert len(arg.indices) == 1, "Ethos-U passes expects flattened buffers" + + offset = int(arg.indices[0]) # Note by convention the arg after a constant read is the length of the read length = int(stmt.args[i + 1]) # If it's anything other than a full read, create a new buffer if offset != 0 or len(const) != length: new_consts.append(const[offset : offset + length]) - new_buffer = tvm.tir.decl_buffer((length,), arg.dtype) + new_buffer = tvm.tir.decl_buffer( + (length,), arg.dtype, scope=arg.buffer.scope() + ) new_buffers.append(new_buffer) - new_args.append(tvm.tir.expr.Load(new_buffer.dtype, new_buffer.data, 0)) + new_args.append(tvm.tir.expr.BufferLoad(new_buffer, [0])) continue - keep_buffers.add(arg.buffer_var) + keep_buffers.add(arg.buffer.data) new_args.append(arg) @@ -278,7 +283,15 @@ def _ftransform(f, mod, ctx): new_buffer_map[handle] = new_buffer new_const_dict[len(new_params) - 1] = new_consts[i] - new_f = tvm.tir.PrimFunc(new_params, new_body, f.ret_type, new_buffer_map, f.attrs, f.span) + new_f = tvm.tir.PrimFunc( + new_params, + new_body, + f.ret_type, + new_buffer_map, + f.preflattened_buffer_map, + f.attrs, + f.span, + ) return new_f def _divide_constants(mod): @@ -302,179 +315,232 @@ def EncodeConstants(const_dict): """ new_const_dict = {} - buffer_to_const = {} - pointer_to_buffer = {} - rewrite_buffer = {} - rewrite_pointer = {} - accel_config = vela_api.get_accelerator_config() - - def _align_scale_bias(tir_extern_call, bias): - """Align the scale_bias to 16 bytes.""" - value_bytes = bytearray() - value_bytes.extend(bias.tobytes()) - # Align to 16 - remainder = (len(value_bytes)) % 16 - if remainder > 0: - value_bytes.extend(bytearray(16 - remainder)) - value = np.frombuffer(value_bytes, dtype="uint8") - return value - - def _encode_weights(tir_extern_call, weights): - """Encode the weights for a TIR extern call.""" - value_bytes = vela_api.encode_weights(tir_extern_call, weights, accel_config) - value = np.frombuffer(value_bytes, dtype="uint8") - return value - - def _new_buffer(old_buffer, new_value): - """Create a new buffer and add the old buffer and its pointer to the - rewriting maps.""" - if old_buffer in rewrite_buffer: - new_buffer = rewrite_buffer[old_buffer] - else: - new_buffer = tvm.tir.decl_buffer((len(new_value),), str(new_value.dtype)) - pointer_to_buffer[new_buffer.data] = new_buffer - 
buffer_to_const[new_buffer] = new_value - - rewrite_buffer[old_buffer] = new_buffer - rewrite_pointer[old_buffer.data] = new_buffer.data - - def _visit_encode_pre(stmt): - if isinstance(stmt, tvm.tir.Call): - # Handle copies as a special-case by propagating the buffer information - # from the read to the write pointer. - if stmt.args[0] == "ethosu_copy": - read_pointer = stmt.args[1].buffer_var - if read_pointer in pointer_to_buffer: - write_pointer = stmt.args[3].buffer_var + + def collect_encoding_definitions(stmt, old_buffer_to_const): + # Map from copy destination to copy source. + copy_map = {} + # List of buffer copies that occurred + copied_buffers = [] + # List of encoded buffer information + constant_buffer_replacements = [] + + def _align_scale_bias(tir_extern_call, bias): + """Align the scale_bias to 16 bytes.""" + value_bytes = bytearray() + value_bytes.extend(bias.tobytes()) + # Align to 16 + remainder = (len(value_bytes)) % 16 + if remainder > 0: + value_bytes.extend(bytearray(16 - remainder)) + value = np.frombuffer(value_bytes, dtype="uint8") + return value + + accel_config = vela_api.get_accelerator_config() + + def _encode_weights(tir_extern_call, weights): + """Encode the weights for a TIR extern call.""" + value_bytes = vela_api.encode_weights(tir_extern_call, weights, accel_config) + value = np.frombuffer(value_bytes, dtype="uint8") + return value + + def _declare_constant_buffer(old_buffer, encoded_constants): + """Create a new buffer and add the old buffer and its pointer to the + rewriting maps.""" + new_buffer = tvm.tir.decl_buffer( + shape=[len(encoded_constants)], + dtype=str(encoded_constants.dtype), + name=old_buffer.name + "_encoded", + scope=old_buffer.scope(), + ) + + constant_buffer_replacements.append( + { + "old_buffer": old_buffer, + "new_buffer": new_buffer, + "encoded_constants": encoded_constants, + } + ) + + def _visit(stmt): + if isinstance(stmt, tvm.tir.Call): + # Handle copies as a special-case by propagating the buffer information + # from the read to the write pointer. + if stmt.args[0] == "ethosu_copy": + read_buffer = stmt.args[1].buffer + write_buffer = stmt.args[3].buffer # Assert writing to the base of the write_var (pre-StorageRewrite) - assert stmt.args[3].index == 0 - assert stmt.args[1].index == 0 - pointer_to_buffer[write_pointer] = pointer_to_buffer[read_pointer] - else: - # Encode the weights - weights_pointer = get_weights_pointer(stmt) - if weights_pointer is not None: - assert weights_pointer in pointer_to_buffer - weights_buffer = pointer_to_buffer[weights_pointer] - weights_value = buffer_to_const[weights_buffer] - new_weights_value = _encode_weights(stmt, weights_value) - _new_buffer(weights_buffer, new_weights_value) - # Align the scale_bias to 16 bytes - scale_bias_pointer = get_scale_bias_pointer(stmt) - if scale_bias_pointer is not None: - assert scale_bias_pointer in pointer_to_buffer - scale_bias_buffer = pointer_to_buffer[scale_bias_pointer] - scale_bias_value = buffer_to_const[scale_bias_buffer] - new_scale_bias_value = _align_scale_bias(stmt, scale_bias_value) - _new_buffer(scale_bias_buffer, new_scale_bias_value) - - def _visit_encode_post(stmt): - # Because encoding may change the data type (e.g. bias to uint8) and type information - # is stored in pointer vars, it's necessary to rewrite all the pointers which point - # to encoded data. 
- if isinstance(stmt, tvm.tir.Allocate): - allocate_pointer = stmt.buffer_var - if allocate_pointer in pointer_to_buffer: - buffer = pointer_to_buffer[allocate_pointer] - if buffer in rewrite_buffer: # If the pointer needs rewriting - # Create a new pointer var with the type of the new buffer - new_buffer = rewrite_buffer[buffer] - storage_type = tvm.ir.PrimType(new_buffer.dtype) - new_pointer = tvm.tir.Var( - allocate_pointer.name, - tvm.ir.PointerType(storage_type, buffer.scope()), - allocate_pointer.span, - ) - # Set the new pointer to resolve to the new buffer - pointer_to_buffer[new_pointer] = new_buffer - # Add the old pointer to the pointer rewriting dict - rewrite_pointer[allocate_pointer] = new_pointer - - def _visit_rewrite(stmt): - if isinstance(stmt, tvm.tir.Call): - # For extern calls, we need to rewrite pairs of arguments corresponding to - # base address load and the length of the load. - new_args = [stmt.args[0]] - new_buffers = rewrite_buffer.values() - for i in range(1, len(stmt.args)): - # If the previous argument was a load, the current should be a length - if isinstance(stmt.args[i - 1], tvm.tir.Load): - load = stmt.args[i - 1] - pointer = load.buffer_var - if pointer in pointer_to_buffer: - buffer = pointer_to_buffer[pointer] - # Only rewrite the arguments of buffers that have been encoded - if buffer in new_buffers: - new_arg = np.prod(list(pointer_to_buffer[pointer].shape)) - new_args.append(new_arg) - continue - new_args.append(stmt.args[i]) - - return tvm.tir.Call(stmt.dtype, stmt.op, new_args, stmt.span) - if isinstance(stmt, tvm.tir.Allocate): - # Where a pointer needs rewriting, the allocate for it must be rewritten - allocate_pointer = stmt.buffer_var - if allocate_pointer in pointer_to_buffer: - if pointer_to_buffer[allocate_pointer] in rewrite_buffer: - new_buffer = rewrite_buffer[pointer_to_buffer[allocate_pointer]] - new_pointer = rewrite_pointer[allocate_pointer] + assert list(stmt.args[3].indices) == [0] + assert list(stmt.args[1].indices) == [0] + copied_buffers.append({"source": read_buffer, "dest": write_buffer}) + copy_map[write_buffer] = read_buffer + + else: + # Encode the weights + weights_buffer = get_weights_buffer(stmt) + if weights_buffer is not None: + if weights_buffer in copy_map: + weights_buffer = copy_map[weights_buffer] + unencoded_weights_value = old_buffer_to_const[weights_buffer] + encoded_weights_value = _encode_weights(stmt, unencoded_weights_value) + _declare_constant_buffer(weights_buffer, encoded_weights_value) + + # Align the scale_bias to 16 bytes + scale_bias_buffer = get_scale_bias_buffer(stmt) + if scale_bias_buffer is not None: + if scale_bias_buffer in copy_map: + scale_bias_buffer = copy_map[scale_bias_buffer] + scale_bias_value = old_buffer_to_const[scale_bias_buffer] + aligned_scale_bias_value = _align_scale_bias(stmt, scale_bias_value) + _declare_constant_buffer(scale_bias_buffer, aligned_scale_bias_value) + + tvm.tir.stmt_functor.post_order_visit(stmt, _visit) + + return { + "copied_buffers": copied_buffers, + "constant_buffer_replacements": constant_buffer_replacements, + } + + def transform_stmt(stmt, buf_remap, var_remap, pointer_to_buffer, new_buffer_to_const): + def _visit_rewrite(stmt): + if isinstance(stmt, tvm.tir.Call): + # For extern calls, we need to rewrite pairs of arguments corresponding to + # base address load and the length of the load. 
+ old_args = list(stmt.args) + + new_args = [stmt.args[0]] + for prev_arg, arg in zip(old_args[:-1], old_args[1:]): + # If the previous argument was a load from an + # encoded buffer, the current should be a length. + if ( + isinstance(prev_arg, tvm.tir.BufferLoad) + and prev_arg.buffer in new_buffer_to_const + ): + arg = np.prod(list(prev_arg.buffer.shape)) + + new_args.append(arg) + + return tvm.tir.Call(stmt.dtype, stmt.op, new_args, stmt.span) + + if isinstance(stmt, tvm.tir.Allocate): + # Where a pointer needs rewriting, the allocate for it must be rewritten + allocate_pointer = stmt.buffer_var + if allocate_pointer in var_remap: + new_allocate_pointer = var_remap[allocate_pointer] + new_buffer = pointer_to_buffer[new_allocate_pointer] + return tvm.tir.Allocate( - new_pointer, + new_buffer.data, new_buffer.dtype, new_buffer.shape, stmt.condition, stmt.body, stmt.span, ) - # The following rewrites would be better expressed by just rewriting the Vars, however - # ir_transform doesn't seem to visit Vars. So instead we do the next best thing and rewrite - # the nodes which contain the Vars. - if isinstance(stmt, tvm.tir.Load): - load_pointer = stmt.buffer_var - if load_pointer in rewrite_pointer: - new_pointer = rewrite_pointer[load_pointer] - element_type = new_pointer.type_annotation.element_type.dtype - return tvm.tir.Load( - element_type, new_pointer, stmt.index, stmt.predicate, stmt.span - ) - if isinstance(stmt, tvm.tir.AttrStmt): - node_pointer = stmt.node - if node_pointer in rewrite_pointer: - return tvm.tir.AttrStmt( - rewrite_pointer[node_pointer], stmt.attr_key, stmt.value, stmt.body, stmt.span - ) - return None + + # The following rewrites would be better expressed by just + # rewriting the Buffers. However ir_transform doesn't + # visit Buffers, so instead we do the next best thing and + # rewrite the nodes which contain the Buffers. + if isinstance(stmt, tvm.tir.BufferLoad): + if stmt.buffer in buf_remap: + return tvm.tir.BufferLoad(buf_remap[stmt.buffer], stmt.indices, stmt.span) + + if isinstance(stmt, tvm.tir.AttrStmt): + node_pointer = stmt.node + if node_pointer in var_remap: + return tvm.tir.AttrStmt( + var_remap[node_pointer], + stmt.attr_key, + stmt.value, + stmt.body, + stmt.span, + ) + + return None + + return tvm.tir.stmt_functor.ir_transform( + stmt, + None, + _visit_rewrite, + ["tir.Call", "tir.Allocate", "tir.BufferLoad", "tir.AttrStmt"], + ) def _ftransform(f, mod, ctx): + # Step 0: Unpack the constant dictionary in terms of the + # functions buffers. + old_buffer_to_const = {} for i, param in enumerate(f.params): if i in const_dict: - buffer_to_const[f.buffer_map[param]] = const_dict[i].flatten() - pointer_to_buffer[f.buffer_map[param].data] = f.buffer_map[param] - - # First analyse what needs to be rewritten - new_body = tvm.tir.stmt_functor.ir_transform( - f.body, _visit_encode_pre, _visit_encode_post, ["tir.Call", "tir.Allocate"] - ) - # Then perform the rewrites - new_body = tvm.tir.stmt_functor.ir_transform( - f.body, None, _visit_rewrite, ["tir.Call", "tir.Allocate", "tir.Load", "tir.AttrStmt"] + old_buffer_to_const[f.buffer_map[param]] = const_dict[i].flatten() + + # Step 1: Collect information on the buffers that will be + # replaced by encodings. + buffer_information = collect_encoding_definitions(f.body, old_buffer_to_const) + + # Step 2: Generate variable/buffer remaps, based on the + # collected information. 
+ buf_remap = {} + new_buffer_to_const = {} + + # Any encoded buffers must be replaced + for info in buffer_information["constant_buffer_replacements"]: + buf_remap[info["old_buffer"]] = info["new_buffer"] + new_buffer_to_const[info["new_buffer"]] = info["encoded_constants"] + + # Any buffers that are copied into from an encoded buffer must + # be replaced. + for info in buffer_information["copied_buffers"]: + copy_source = info["source"] + while copy_source in buf_remap: + copy_source = buf_remap[copy_source] + + copy_dest = info["dest"] + + if copy_source.shape != copy_dest.shape or copy_source.dtype != copy_dest.dtype: + new_dest = tvm.tir.decl_buffer( + shape=copy_source.shape, + dtype=copy_source.dtype, + name=copy_dest.name, + scope=copy_dest.scope(), + ) + buf_remap[copy_dest] = new_dest + if copy_source in new_buffer_to_const: + new_buffer_to_const[new_dest] = new_buffer_to_const[copy_source] + + # Define additional dependent lookup tables. + var_remap = {old.data: new.data for (old, new) in buf_remap.items()} + pointer_to_buffer = { + buf.data: buf for (old, new) in buf_remap.items() for buf in [old, new] + } + + # Step 3: Then perform the rewrites + new_body = transform_stmt( + f.body, buf_remap, var_remap, pointer_to_buffer, new_buffer_to_const ) + + # Step 4: Rewrite the buffer map and const dict to instead use the encoded versions new_buffer_map = {} - # Rewrite the buffer map and const dict to instead use the encoded versions for i, param in enumerate(f.params): buffer = f.buffer_map[param] - if buffer in rewrite_buffer: - new_buffer = rewrite_buffer[buffer] - new_buffer_map[param] = new_buffer - new_value = buffer_to_const[new_buffer] - new_const_dict[i] = new_value - elif buffer in buffer_to_const: - new_const_dict[i] = buffer_to_const[buffer] - new_buffer_map[param] = buffer - else: - new_buffer_map[param] = buffer - - new_f = tvm.tir.PrimFunc(f.params, new_body, f.ret_type, new_buffer_map, f.attrs, f.span) + if buffer in buf_remap: + buffer = buf_remap[buffer] + + if buffer in new_buffer_to_const: + new_const_dict[i] = new_buffer_to_const[buffer] + elif buffer in old_buffer_to_const: + new_const_dict[i] = old_buffer_to_const[buffer] + + new_buffer_map[param] = buffer + + new_f = tvm.tir.PrimFunc( + f.params, + new_body, + f.ret_type, + new_buffer_map, + f.preflattened_buffer_map, + f.attrs, + f.span, + ) return new_f def _encode_constants(mod): @@ -706,15 +772,26 @@ def CreatePrimFuncWithoutConstants(const_dict): def _ftransform(f, mod, ctx): new_params = list() new_buffer_map = dict() + new_preflattened_buffer_map = dict() for param_idx in const_dict.keys(): # We are using buffer_var to key the constants as # PrimFunc params of constants will be removed. 
new_const_dict[f.buffer_map[f.params[param_idx]].data] = const_dict[param_idx] - for i in range(len(f.params)): + for i, param in enumerate(f.params): if i not in const_dict.keys(): - new_params.append(f.params[i]) - new_buffer_map[f.params[i]] = f.buffer_map[f.params[i]] - return tvm.tir.PrimFunc(new_params, f.body, f.ret_type, new_buffer_map, f.attrs, f.span) + new_params.append(param) + new_buffer_map[param] = f.buffer_map[param] + if param in f.preflattened_buffer_map: + new_preflattened_buffer_map[param] = f.preflattened_buffer_map[param] + return tvm.tir.PrimFunc( + new_params, + f.body, + f.ret_type, + new_buffer_map, + new_preflattened_buffer_map, + f.attrs, + f.span, + ) def _create_primfunc_without_constants(mod): transform_func = tvm.tir.transform.prim_func_pass( diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py index e929caa2409b..3b32ef01a938 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py @@ -61,8 +61,8 @@ def get_pooling_params( loads = get_loads(rw.body) # stores = [output] stores = get_stores(rw.body) - input_pointer = loads[1].buffer_var - output_pointer = stores[0].buffer_var + input_pointer = loads[1].buffer.data + output_pointer = stores[0].buffer.data # Get feature map info serial_ifm, serial_padding = get_ifm_params(input_pointer, producers) serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py index f9d38df9d901..d390fc0e10dc 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py @@ -93,10 +93,10 @@ def __init__( tile_height_0: int, tile_height_1: int, tile_width_0: int, - tile_address_0: tvm.tir.expr.Load, - tile_address_1: Union[tvm.tir.expr.Load, int], - tile_address_2: Union[tvm.tir.expr.Load, int], - tile_address_3: Union[tvm.tir.expr.Load, int], + tile_address_0: tvm.tir.expr.BufferLoad, + tile_address_1: Union[tvm.tir.expr.BufferLoad, int], + tile_address_2: Union[tvm.tir.expr.BufferLoad, int], + tile_address_3: Union[tvm.tir.expr.BufferLoad, int], scale: float, zero_point: int, layout: str, @@ -148,7 +148,7 @@ class SerialAddressRange(SerializableFormat): """Specialization class to retrieve arguments of a AddressRange (similiar to NpuAddressRange of Vela) on a predefined ordering""" - def __init__(self, address: tvm.tir.expr.Load, length: int): + def __init__(self, address: tvm.tir.expr.BufferLoad, length: int): self.address = address self.length = length @@ -237,7 +237,10 @@ class SerialCopy(SerializableFormat): a ethosu.copy tir extern call on a predefined ordering""" def __init__( - self, read_address: tvm.tir.expr.Load, length: int, write_address: tvm.tir.expr.Load + self, + read_address: tvm.tir.expr.BufferLoad, + length: int, + write_address: tvm.tir.expr.BufferLoad, ): self.read_address = read_address self.length = length diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py index 141505a3dfba..53e0bd2a728b 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py @@ -50,17 +50,16 @@ def get_copy_params(stmt, producers, consumers): _, body = get_op_attrs(stmt) length = body.extent write_store = body.body - write_base = 
get_base_address(write_store.index) + write_base = [get_base_address(index) for index in write_store.indices] read_load = body.body.value - read_base = get_base_address(read_load.index) - dtype = body.body.value.dtype + read_base = [get_base_address(index) for index in read_load.indices] return ( SerialCopy( - read_address=tvm.tir.expr.Load(dtype, read_load.buffer_var, read_base), + read_address=tvm.tir.expr.BufferLoad(read_load.buffer, read_base), length=length, - write_address=tvm.tir.expr.Load(dtype, write_store.buffer_var, write_base), + write_address=tvm.tir.expr.BufferLoad(write_store.buffer, write_base), ), - write_store.buffer_var, + write_store.buffer.data, None, True, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py index b550b79e7906..9c570d88c163 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py @@ -54,11 +54,11 @@ def get_unary_elementwise_params(stmt, producers, consumers): input_pointer = None if isinstance(inner.value, tir.expr.Select): # ABS - input_pointer = inner.value.condition.b.buffer_var + input_pointer = inner.value.condition.b.buffer.data if isinstance(inner.value, tir.expr.Sub): # CLZ - input_pointer = inner.value.b.args[0].buffer_var - output_pointer = inner.buffer_var + input_pointer = inner.value.b.args[0].buffer.data + output_pointer = inner.buffer.data # Get feature map info serial_ifm, _ = get_ifm_params(input_pointer, producers) serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py index de1c0ab19f6e..506f18ba3a99 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py @@ -21,20 +21,20 @@ # TODO(@mbaret): Formalise this with a specification -def get_weights_pointer(tir_extern_call): +def get_weights_buffer(tir_extern_call): """Get the weights pointer from a NPU extern call if it exists""" supported_ops = ["ethosu_conv2d", "ethosu_depthwise_conv2d"] if tir_extern_call.args[0] in supported_ops: - return tir_extern_call.args[41].buffer_var + return tir_extern_call.args[41].buffer return None # TODO(@mbaret): Formalise this with a specification -def get_scale_bias_pointer(tir_extern_call): +def get_scale_bias_buffer(tir_extern_call): """Get the scale_bias pointer from a NPU extern call if it exists""" supported_ops = ["ethosu_conv2d", "ethosu_depthwise_conv2d"] if tir_extern_call.args[0] in supported_ops: - return tir_extern_call.args[44].buffer_var + return tir_extern_call.args[44].buffer return None @@ -177,23 +177,23 @@ def get_outer_loops(stmt, layout): def get_loads(stmt): - """Get the Load statements. + """Get the BufferLoad statements. Parameters ---------- stmt : tvm.tir.Stmt - The statement to get the Loads from. + The statement to get the BufferLoads from. Returns ------- - loads : list of tvm.tir.Load - The Loads found. + loads : list of tvm.tir.BufferLoad + The BufferLoads found. """ loads = [] def _visit(s): - if isinstance(s, tvm.tir.Load): + if isinstance(s, tvm.tir.BufferLoad): loads.append(s) tvm.tir.stmt_functor.post_order_visit(stmt, _visit) @@ -201,23 +201,23 @@ def _visit(s): def get_stores(stmt): - """Get the Store statements. + """Get the BufferStore statements. 
Parameters ---------- stmt : tvm.tir.Stmt - The statement to get the Stores from. + The statement to get the BufferStores from. Returns ------- - stores : list of tvm.tir.Store - The Stores found. + stores : list of tvm.tir.BufferStore + The BufferStores found. """ stores = [] def _visit(s): - if isinstance(s, tvm.tir.Store): + if isinstance(s, tvm.tir.BufferStore): stores.append(s) tvm.tir.stmt_functor.post_order_visit(stmt, _visit) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index f642f5f7cfad..33a22d1a09fb 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -122,9 +122,9 @@ def analyze_pool_access(stmt): if isinstance(stmt, tvm.tir.stmt.LetStmt): call_address_of = stmt.value load = call_address_of.args[0] - pool_var = load.buffer_var + pool_var = load.buffer.data scratch_region_map[stmt.var] = RegionOffset( - region=pool_var_region_map[pool_var], offset=int(load.index) + region=pool_var_region_map[pool_var], offset=int(load.indices[0]) ) tvm.tir.stmt_functor.post_order_visit(primfunc.body, analyze_pool_access) @@ -334,6 +334,8 @@ def extract_buffer_info( primfunc = mod.functions.items()[0][1] for param, const_data in param_dict.items(): + if isinstance(param, tvm.tir.Buffer): + param = param.data buffer_info[param] = BufferInfo( const_data, const_data.shape, const_data.dtype, BufferType.constant ) @@ -385,6 +387,7 @@ def assign_addresses(buffer_info, npu_ops, scratch_region_map): This is the dictionary obtained via calling extract_buffer_info. The key is the buffer name to BufferInfo npu_ops : list + A list of Vela NpuOps with tir.BufferLoads for addresses A list of Vela NpuOps with tir.Loads for addresses scratch_region_map : Dict[tvm.tir.Var, RegionOffset] A buffer_var to region and offset map. @@ -397,14 +400,13 @@ def assign_addresses(buffer_info, npu_ops, scratch_region_map): """ def replace_npu_fm_with_address(npu_fm): - assert isinstance(npu_fm.tiles.addresses[0], tvm.tir.Load) + assert isinstance(npu_fm.tiles.addresses[0], tvm.tir.BufferLoad) # We currently does not support tiles # Change this when tiles are needed # (i.e. 
when using rolling buffers) assert npu_fm.tiles.addresses[1:] == [0, 0, 0] npu_fm.tiles.addresses[1:] = [0, 0, 0] - buffer = npu_fm.tiles.addresses[0].buffer_var - + buffer = npu_fm.tiles.addresses[0].buffer.data if buffer in scratch_region_map.keys(): address = scratch_region_map[buffer].offset region = scratch_region_map[buffer].region @@ -412,8 +414,10 @@ def replace_npu_fm_with_address(npu_fm): assert buffer in buffer_addresses.keys() address, buffer_type = buffer_addresses[buffer] region = _get_region(buffer_type) - - index = npu_fm.tiles.addresses[0].index * ( + assert ( + len(npu_fm.tiles.addresses[0].indices) == 1 + ), "Ethos-U translation expects flattened buffers" + index = npu_fm.tiles.addresses[0].indices[0] * ( np.iinfo(np.dtype(npu_fm.tiles.addresses[0])).bits // 8 ) npu_fm.tiles.addresses[0] = address + int(index) @@ -421,10 +425,11 @@ def replace_npu_fm_with_address(npu_fm): return npu_fm def replace_npu_address_range_with_address(npu_addr_range): - assert isinstance(npu_addr_range.address, tvm.tir.Load) - buffer = npu_addr_range.address.buffer_var + assert isinstance(npu_addr_range.address, tvm.tir.BufferLoad) + buffer = npu_addr_range.address.buffer.data index = int( - npu_addr_range.address.index * (np.iinfo(np.dtype(npu_addr_range.address)).bits // 8) + npu_addr_range.address.indices[0] + * (np.iinfo(np.dtype(npu_addr_range.address)).bits // 8) ) if buffer in scratch_region_map.keys(): return vapi.NpuAddressRange( @@ -446,11 +451,11 @@ def replace_tir_loads(npu_object): def classify_io(buffer): for _npu_op in npu_ops: if issubclass(type(_npu_op), vapi.NpuBlockOperation): - if _npu_op.ifm and _npu_op.ifm.tiles.addresses[0].buffer_var == buffer: + if _npu_op.ifm and _npu_op.ifm.tiles.addresses[0].buffer.data == buffer: return BufferType.input - if _npu_op.ifm2 and _npu_op.ifm2.tiles.addresses[0].buffer_var == buffer: + if _npu_op.ifm2 and _npu_op.ifm2.tiles.addresses[0].buffer.data == buffer: return BufferType.input - if _npu_op.ofm and _npu_op.ofm.tiles.addresses[0].buffer_var == buffer: + if _npu_op.ofm and _npu_op.ofm.tiles.addresses[0].buffer.data == buffer: return BufferType.output raise ValueError(f"Unused IO : {buffer} in tir module.") diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index fbd5c3dd07aa..c90fd683220e 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -200,10 +200,10 @@ def compute_unique(attrs, inputs, output_type): @script def _arange_shape_func(start, stop, step): out = output_tensor((1,), "int64") - if step[0] < 0: - out[0] = int64(ceil_div((int64(start[0]) - int64(stop[0])), int64(-step[0]))) + if step[()] < 0: + out[0] = int64(ceil_div((int64(start[()]) - int64(stop[()])), int64(-step[()]))) else: - out[0] = int64(ceil_div((int64(stop[0]) - int64(start[0])), int64(step[0]))) + out[0] = int64(ceil_div((int64(stop[()]) - int64(start[()])), int64(step[()]))) return out diff --git a/python/tvm/relay/op/dyn/_transform.py b/python/tvm/relay/op/dyn/_transform.py index c909764319d9..d523d43d9c64 100644 --- a/python/tvm/relay/op/dyn/_transform.py +++ b/python/tvm/relay/op/dyn/_transform.py @@ -170,7 +170,7 @@ def _onehot_shape_func(dshape, k, axis): out = output_tensor((ndim,), "int64") for i in const_range(axis): out[i] = int64(dshape[i]) - out[axis] = int64(k[0]) + out[axis] = int64(k[(0)]) for j in const_range(axis + 1, ndim): out[j] = int64(dshape[j - 1]) return out diff --git a/python/tvm/relay/op/dyn/nn/_nn.py b/python/tvm/relay/op/dyn/nn/_nn.py index 
727715141230..ec4066561fce 100644 --- a/python/tvm/relay/op/dyn/nn/_nn.py +++ b/python/tvm/relay/op/dyn/nn/_nn.py @@ -78,8 +78,8 @@ def _upsampling_shape_func(dshape, scale_h, scale_w, height_axis, width_axis): out = output_tensor((4,), "int64") for i in const_range(4): out[i] = int64(dshape[i]) - out[height_axis] = int64(round(dshape[height_axis] * scale_h[0])) - out[width_axis] = int64(round(dshape[width_axis] * scale_w[0])) + out[height_axis] = int64(round(dshape[height_axis] * scale_h[()])) + out[width_axis] = int64(round(dshape[width_axis] * scale_w[()])) return out @@ -108,9 +108,9 @@ def _upsampling3d_shape_func( out = output_tensor((5,), "int64") for i in const_range(5): out[i] = int64(dshape[i]) - out[depth_axis] = int64(round(dshape[depth_axis] * scale_d[0])) - out[height_axis] = int64(round(dshape[height_axis] * scale_h[0])) - out[width_axis] = int64(round(dshape[width_axis] * scale_w[0])) + out[depth_axis] = int64(round(dshape[depth_axis] * scale_d[()])) + out[height_axis] = int64(round(dshape[height_axis] * scale_h[()])) + out[width_axis] = int64(round(dshape[width_axis] * scale_w[()])) return out diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py index 149e17bcc701..972e5845fcb9 100644 --- a/python/tvm/script/context_maintainer.py +++ b/python/tvm/script/context_maintainer.py @@ -127,6 +127,8 @@ class ContextMaintainer: """List[Var]: The function parameters""" func_buffer_map: Mapping[Var, Buffer] = {} """Mapping[Var, Buffer]: The function buffer map""" + func_preflattened_buffer_map: Mapping[Var, Buffer] = {} + """Mapping[Var, Buffer]: The function buffer map, prior to any flattening.""" func_dict_attr: Mapping[str, Object] = {} """Mapping[str, Object]: The function attrs""" func_var_env_dict: Mapping[Var, str] = {} @@ -151,6 +153,7 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No # function context self.func_params = [] self.func_buffer_map = {} + self.func_preflattened_buffer_map = {} self.func_dict_attr = {} self.func_var_env_dict = {} # parser and analyzer diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 587fbe44a174..17beb8169d3b 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -484,6 +484,7 @@ def check_decorator(decorators: List[ast.Expr]) -> bool: body, ret_type, buffer_map=self.context.func_buffer_map, + preflattened_buffer_map=self.context.func_preflattened_buffer_map, attrs=tvm.ir.make_node("DictAttrs", **dict_attr) if dict_attr else None, span=tvm_span_from_synr(node.span), ) @@ -552,7 +553,11 @@ def transform_Assign(self, node): if isinstance(node.rhs, ast.Call): # Pattern 1 & Pattern 4 - func = self.transform(node.rhs.func_name) + if isinstance(node.rhs.func_name, ast.Op): + func = None + else: + func = self.transform(node.rhs.func_name) + if isinstance(func, WithScopeHandler): if not func.concise_scope or not func.def_symbol: self.report_error( @@ -610,6 +615,12 @@ def transform_SubscriptAssign(self, node): rhs = self.transform(node.params[2]) rhs_span = tvm_span_from_synr(node.params[2].span) if isinstance(symbol, tvm.tir.Buffer): + if len(indexes) != len(symbol.shape): + self.report_error( + f"Buffer {symbol.name} is {len(symbol.shape)}-dimensional, " + f"cannot be indexed by {len(indexes)}-dimensional indices.", + node.params[1].span, + ) # BufferStore return tvm.tir.BufferStore( symbol, @@ -629,15 +640,29 @@ def transform_SubscriptAssign(self, node): f"Store is only allowed with one index, but {len(indexes)} were 
provided.", node.params[1].span, ) - # Store - return tvm.tir.Store( - symbol, - tvm.runtime.convert(rhs, span=rhs_span), - indexes[0], - tvm.runtime.convert(True, span=tvm_span_from_synr(node.span)), - span=tvm_span_from_synr(node.span), + self.report_error( + "Use of tir.Store has been deprecated in favor of tir.BufferStore.", node.span + ) + + def transform_AttrAssign(self, node): + """Visitor for statements of the form :code:`x.y = 2`.""" + obj = self.transform(node.params[0]) + field = node.params[1] + value = self.transform(node.params[2]) + + if not hasattr(obj, field.name): + self.error(f"Field {field.name} does not exist", field.span) + + var = getattr(obj, field.name) + + if not isinstance(var, tvm.tir.Var): + self.error( + f"Can only assign to tir.Var attributes, not {type(var).__name__}", node.span ) + body = self.parse_body(node) + return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span)) + def transform_Assert(self, node): """Assert visitor @@ -866,13 +891,16 @@ def f(): """ # Only allowed builtin operator that can be a statement is x[1] = 3 i.e. subscript assign. if isinstance(node.call.func_name, ast.Op): - if node.call.func_name.name != ast.BuiltinOp.SubscriptAssign: - self.report_error( - "Binary and unary operators are not allowed as a statement", node.span - ) - else: + if node.call.func_name.name == ast.BuiltinOp.SubscriptAssign: return self.transform_SubscriptAssign(node.call) + if node.call.func_name.name == ast.BuiltinOp.AttrAssign: + return self.transform_AttrAssign(node.call) + + self.report_error( + "Binary and unary operators are not allowed as a statement", node.span + ) + # handle a regular function call func = self.transform(node.call.func_name) arg_list = self.parse_arg_list(func, node.call) @@ -952,15 +980,8 @@ def transform_Subscript(self, node): node.span, ) - return call_with_error_reporting( - self.report_error, - node.span, - tvm.tir.Load, - "float32", - symbol, - index, - True, - span=tvm_span_from_synr(node.span), + self.report_error( + "Use of tir.Load has been deprecated in favor of tir.BufferLoad", node.span ) elif isinstance(symbol, tvm.tir.Buffer): return BufferSlice( diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index ac4ee3018f7c..0593236512a1 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -311,7 +311,7 @@ def allocate( scope: str, condition: Union[PrimExpr, builtins.bool] = True, annotations: Optional[Mapping[str, Object]] = None, -) -> Var: ... +) -> Buffer: ... def launch_thread(env_var: Var, extent: Union[int, PrimExpr]) -> Var: ... 
def realize( buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True diff --git a/python/tvm/script/tir/node.py b/python/tvm/script/tir/node.py index 8564fc10eecf..4dc78ba66064 100644 --- a/python/tvm/script/tir/node.py +++ b/python/tvm/script/tir/node.py @@ -96,7 +96,8 @@ def check_index(index: Union[int, PrimExpr]): if index < 0: report_error("Negative index is not allowed during buffer access", span) elif isinstance(index, PrimExpr): - if index.dtype != "int32": + element_dtype = index.dtype.split("x", maxsplit=1)[0] + if element_dtype != "int32": report_error( "index expected an int32 type PrimExpr but got " + str(index.dtype), index.span, @@ -153,3 +154,6 @@ def asobject(self) -> BufferLoad: indices = [s.start for s in self.slices] return BufferLoad(self.buffer, indices, span=self.span) + + def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr: + return self.asobject().astype(dtype, span) diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index 07ba20423161..2da7b78b16cd 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -112,10 +112,16 @@ def allocate(extents, dtype, scope, condition=True, annotations=None, span=None) condition = tvm.runtime.convert(condition) scope = tvm.runtime.convert(scope) + # Currently, allocate nodes should only occur after buffer + # flattening has been applied. This can be simplified in + # the future by having the AllocateNode hold a buffer + # object directly. + flattened = self.buffer.get_flattened_buffer() + return tvm.tir.Allocate( - self.buffer_var, - dtype, - extents, + self.buffer.data, + flattened.dtype, + flattened.shape, condition, self.body, annotations=annotations, @@ -123,7 +129,7 @@ def allocate(extents, dtype, scope, condition=True, annotations=None, span=None) ) super().__init__(allocate, concise_scope=True, def_symbol=True) - self.buffer_var = None + self.buffer = None def enter_scope( self, @@ -147,15 +153,20 @@ def enter_scope( else: raise Exception("Internal Bug") - def setup_buffer_var( + def setup_buffer( extents, dtype, scope, condition=True, annotations=None, span: Span = None ): - """Setup buffer var for a given type.""" - buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope) - self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span) + """Setup buffer object for a given type.""" + self.buffer = tvm.tir.decl_buffer( + shape=extents, + dtype=dtype, + name=name, + scope=scope, + span=span, + ) - setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span)) - context.update_symbol(name, self.buffer_var, node) + setup_buffer(*arg_list, span=tvm_span_from_synr(var_span)) + context.update_symbol(name, self.buffer, node) @register @@ -171,11 +182,11 @@ def allocate_const(raw_data, dtype, shape, span=None): for i in raw_data: list_data.append(i.value) nd_data = tvm.nd.array(np.asarray(list_data, dtype=dtype)) - n = tvm.tir.AllocateConst(self.buffer_var, dtype, shape, nd_data, self.body, span=span) + n = tvm.tir.AllocateConst(self.buffer.data, dtype, shape, nd_data, self.body, span=span) return n super().__init__(allocate_const, concise_scope=True, def_symbol=True) - self.buffer_var = None + self.buffer = None def enter_scope( self, @@ -199,13 +210,17 @@ def enter_scope( else: raise Exception("Internal Bug") - def setup_buffer_var(data, dtype, shape, span: Span = None): + def setup_buffer(data, dtype, shape, span: Span = None): """Setup buffer var for a given type.""" - buffer_ptr_type = 
tvm.ir.PointerType(tvm.ir.PrimType(dtype)) - self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span) + self.buffer = tvm.tir.decl_buffer( + shape=shape, + dtype=dtype, + name=name, + span=span, + ) - setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span)) - context.update_symbol(name, self.buffer_var, node) + setup_buffer(*arg_list, span=tvm_span_from_synr(var_span)) + context.update_symbol(name, self.buffer, node) @register diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 20161ad106c1..d9c6dbda47b2 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -864,6 +864,60 @@ def func_attr(dict_attr, span): super().__init__(func_attr, def_symbol=False) +@register +class PreflattenedBufferMap(SpecialStmt): + """Special Stmt for declaring the PrimFunc::preflattened_buffer_map + + Example + ------- + .. code-block:: python + T.preflattened_buffer(postflattened, shape) + """ + + def __init__(self): + def preflattened_buffer( + postflattened, + shape, + dtype="float32", + data=None, + strides=None, + elem_offset=None, + scope="global", + align=-1, + offset_factor=0, + buffer_type="default", + span=None, + ): + + param = None + for key, value in self.context.func_buffer_map.items(): + if value.same_as(postflattened): + param = key + + assert ( + param is not None + ), f"Post-flattened buffer {postflattened.name} does not appear in the buffer map." + + buffer_name: str = f"{postflattened.name}_preflatten" + preflattened = tvm.tir.decl_buffer( + shape, + dtype, + buffer_name, + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + span=span, + ) + + self.context.func_preflattened_buffer_map[param] = preflattened + + super().__init__(preflattened_buffer, def_symbol=False) + + @register class TargetAttrValue(SpecialStmt): """Special Stmt for target attr value. diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py index aaad6e108e7b..4c4e223f2d72 100644 --- a/python/tvm/te/__init__.py +++ b/python/tvm/te/__init__.py @@ -27,7 +27,13 @@ from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod from tvm.tir import comm_reducer, min, max, sum -from .schedule import Schedule, Stage, create_schedule, SpecializedCondition +from .schedule import ( + Schedule, + Stage, + create_schedule, + SpecializedCondition, + AXIS_SEPARATOR, +) from .tensor import TensorSlice, Tensor from .tensor_intrin import decl_tensor_intrin from .tag import tag_scope diff --git a/python/tvm/te/schedule.py b/python/tvm/te/schedule.py index 55d07a57e3e4..fdd08f9208c9 100644 --- a/python/tvm/te/schedule.py +++ b/python/tvm/te/schedule.py @@ -16,12 +16,16 @@ # under the License. # pylint: disable=unused-import """The computation schedule api of TVM.""" +import collections +import inspect +from typing import Callable, List + import tvm._ffi from tvm._ffi.base import string_types from tvm.runtime import Object, convert from tvm.ir import container as _container -from tvm.tir import IterVar, Buffer +from tvm.tir import IterVar, Buffer, Var from . import tensor as _tensor from . import _ffi_api @@ -519,9 +523,149 @@ def rolling_buffer(self): """ _ffi_api.StageRollingBuffer(self) + def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr]]): + """Defines the layout transformation for the current stage's tensor. + + The map from initial_indices to final_indices must be an + invertible affine transformation.
This method may be called + more than once for a given tensor, in which case each + transformation is applied sequentially. + + If the stage is a ComputeOp, then the iteration order of the + compute stage is rewritten to be a row-major traversal of the + tensor, and the new loop iteration variables are returned. + For all other stages, the loop iteration order is unmodified, + and the return value is None. + + Parameters + ---------- + mapping_function : Callable[..., List[tvm.tir.PrimExpr]] + + A callable that accepts N arguments of type tvm.tir.Var, + and outputs a list of PrimExpr. The input arguments + represent the location of a value in the current stage's + tensor, using the pre-transformation layout. The return + value of the function gives the location of that value in + the current stage's tensor, using the post-transformation + layout. + + Returns + ------- + new_iter_vars : Optional[List[tvm.tir.IterVar]] + + If the stage is a ComputeOp, then the return will be the + updated loop iteration variables over the data array, in + the same order as the output values from the + `mapping_function`. + + Otherwise, the return value is None. + + Examples + -------- + .. code-block:: python + + # ``A`` is a tensor whose compute definition is in NHWC + # format, and should be transformed into NCHWc format. + + s[A].transform_layout( + lambda n,h,w,c: [n, c//4, h, w, c%4] + ) + + + .. code-block:: python + + # ``A`` is a tensor whose compute definition is in an + # arbitrary format, and should be transformed such that + # the last index is split, with the slower-changing index + # of the split placed at the slowest changing dimension. + + s[A].transform_layout( + lambda *indices, i: [i//4, *indices, i%4] + ) + + .. code-block:: python + + # ``B`` is a tensor defined by te.compute to be a copy of + # ``A``, and should be transformed such that ``B``'s layout + # is a transpose of ``A``'s layout. The loop iteration + # that computes ``B`` will correspond to ``B``'s memory + # layout. + + A = te.placeholder([n,m]) + B = te.compute(A.shape, lambda i,j: A[i,j]) + s = te.create_schedule(B.op) + + s[B].transform_layout(lambda i,j: [j,i]) + + """ + + args = [] + var_arg_name = None + kwargs = collections.OrderedDict() + default_index_dtype = "int32" + + # Make a dummy variable for each explicitly named input index. + # We may have some keyword-only arguments, if the function has + # *args before the last argument. + params = inspect.signature(mapping_function).parameters + for name, param in params.items(): + if param.kind in [ + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ]: + args.append(tvm.tir.Var(name, default_index_dtype)) + + elif param.kind == inspect.Parameter.VAR_POSITIONAL: + var_arg_name = name + + elif param.kind == inspect.Parameter.KEYWORD_ONLY: + kwargs[name] = tvm.tir.Var(name, default_index_dtype) + + elif param.kind in [inspect.Parameter.VAR_KEYWORD]: + raise ValueError("transform_layout mapping may not have **kwargs") + + ndim = len(self.op.output(0).shape) + + # Now that all the named arguments have been collected, + # everything that remains should go to the *args, if + # specified.
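+        # e.g. given ``lambda n, *spatial, c: ...`` on a 4-d tensor,
+        # ``args`` holds the var for ``n``, ``kwargs`` the var for ``c``,
+        # and two vars are generated below to stand in for ``spatial``.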
+ if var_arg_name is not None: + num_var_args = ndim - len(args) - len(kwargs) + for i in range(num_var_args): + args.append(tvm.tir.Var(f"{var_arg_name}[{i}]", default_index_dtype)) + + initial_indices = args + list(kwargs.values()) + if len(initial_indices) != ndim: + raise ValueError( + f"transform_layout mapping accepts {len(params)} initial indices, " + f"but {self.op.name} is {ndim}-dimensional" + ) + + mapping = mapping_function(*args, **kwargs) + + final_indices = [] + axis_separators = [] + for val in mapping: + if isinstance(val, tvm.ir.PrimExpr): + final_indices.append(val) + elif val is AXIS_SEPARATOR: + axis_separators.append(len(final_indices)) + else: + raise TypeError( + "Expected mapping function to return list of " + "either tvm.ir.PrimExpr or tvm.te.AXIS_SEPARATOR. " + f"Instead received {val} of type {type(val)}." + ) + + new_iter_vars = _ffi_api.StageTransformLayout(self, initial_indices, final_indices) + _ffi_api.StageSetAxisSeparators(self, axis_separators) + + return new_iter_vars or None + @tvm._ffi.register_object class SpecializedCondition(Object): + """Specialized condition to enable op specialization.""" def __init__(self, conditions): @@ -555,4 +699,10 @@ def __exit__(self, ptype, value, trace): _ffi_api.ExitSpecializationScope(self) +# Sentinel value used to indicate which groups of pre-flattening axes +# should be flattened together into post-flattening axes. See # Stage.transform_layout for more details. +AXIS_SEPARATOR = "axis_separator" + + tvm._ffi._init_api("schedule", __name__) diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index 6dddd7b119a0..e36a99339e48 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -143,6 +143,33 @@ def scope(self): """ return _ffi_api.BufferStorageScope(self) # type: ignore + def get_flattened_buffer(self): + """Generate a Buffer that is a flattened version of this buffer. + + Returns + ------- + flattened : Buffer + The corresponding flat buffer. + """ + return _ffi_api.BufferGetFlattenedBuffer(self) # type: ignore + + def offset_of(self, indices): + """Determine the offset of the provided indices in the flattened buffer. + + Parameters + ---------- + indices : Union[PrimExpr, List[PrimExpr]] + + The indices of the element in the original buffer. + + Returns + ------- + flattened_indices: List[PrimExpr] + + The offset indices of the element in the flattened buffer. + """ + return _ffi_api.BufferOffsetOf(self, indices) # type: ignore + def decl_buffer( shape, @@ -155,6 +182,7 @@ def decl_buffer( data_alignment=-1, offset_factor=0, buffer_type="", + axis_separators=None, span=None, ): """Declare a new symbolic buffer. @@ -204,6 +232,11 @@ def decl_buffer( without considering whether dimension size equals to one. TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension j's shape equals 1. + axis_separators : list of int, optional + If passed, a list of separators between groups of axes, + each of which is flattened to an output axis. For flat + memory spaces, should either be None, or an empty list. + span: Optional[Span] The location of the decl_buffer creation in the source.
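The ``axis_separators`` argument above and the ``AXIS_SEPARATOR`` sentinel accepted by ``transform_layout`` express the same grouping of axes for flattening. A minimal sketch of both entry points, assuming the APIs declared above (shapes and names illustrative):

.. code-block:: python

    import tvm
    from tvm import te

    A = te.placeholder((128, 64), name="A")
    B = te.compute(A.shape, lambda i, j: A[i, j] * 2.0, name="B")
    s = te.create_schedule(B.op)

    # Indices before the separator are flattened into one physical
    # axis, indices after it into another.
    s[B].transform_layout(lambda i, j: [i, te.AXIS_SEPARATOR, j // 4, j % 4])

    # A similar grouping declared directly on a buffer: axis 0 and
    # axis 1 each flatten to their own output axis.
    buf = tvm.tir.decl_buffer((128, 64), "float32", "buf", axis_separators=[1])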
@@ -254,6 +287,10 @@ def decl_buffer( shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape dtype = "float32" if dtype is None else dtype strides = () if strides is None else strides + + if axis_separators is None: + axis_separators = [] + if offset_factor != 0 and elem_offset is None: shape_dtype = shape[0].dtype if shape and hasattr(shape[0], "dtype") else "int32" elem_offset = Var("%s_elem_offset" % name, shape_dtype) @@ -272,6 +309,7 @@ def decl_buffer( data_alignment, offset_factor, buffer_type, + axis_separators, span, ) diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index bcebab9ddc0a..fdee18f88cf8 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -45,6 +45,9 @@ class PrimFunc(BaseFunc): buffer_map : Map[tvm.tir.Var, tvm.tir.Buffer] The buffer binding map. + preflattened_buffer_map : Optional[Map[tvm.tir.Var, tvm.tir.Buffer]] + The buffer binding map, prior to any flattening. + attrs: Optional[tvm.Attrs] Attributes of the function, can be None @@ -52,9 +55,20 @@ class PrimFunc(BaseFunc): The location of this itervar in the source code. """ - def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, span=None): + def __init__( + self, + params, + body, + ret_type=None, + buffer_map=None, + preflattened_buffer_map=None, + attrs=None, + span=None, + ): + param_list = [] buffer_map = {} if buffer_map is None else buffer_map + preflattened_buffer_map = {} if preflattened_buffer_map is None else preflattened_buffer_map for x in params: x = tvm.runtime.convert(x) if not isinstance(x, Object) else x if isinstance(x, Buffer): @@ -67,8 +81,15 @@ def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, spa raise TypeError("params can only contain Var or Buffer") self.__init_handle_by_constructor__( - _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs, span # type: ignore - ) + _ffi_api.PrimFunc, + param_list, + body, + ret_type, + buffer_map, + preflattened_buffer_map, + attrs, + span, + ) # type: ignore def with_body(self, new_body, span=None): """Create a new PrimFunc with the same set signatures but a new body. @@ -86,7 +107,15 @@ def with_body(self, new_body, span=None): new_func : PrimFunc The created new function. """ - return PrimFunc(self.params, new_body, self.ret_type, self.buffer_map, self.attrs, span) + return PrimFunc( + self.params, + new_body, + self.ret_type, + self.buffer_map, + self.preflattened_buffer_map, + self.attrs, + span, + ) def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]): """Specialize parameters of PrimFunc diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index a71476b23e44..334902b53229 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -15,12 +15,14 @@ # specific language governing permissions and limitations # under the License. """Developer API of IR node builder make function.""" +import tvm from tvm._ffi.base import string_types -from tvm.runtime import ObjectGeneric, DataType, convert, const -from tvm.ir import container as _container, PointerType, PrimType +from tvm.runtime import ObjectGeneric, convert, const +from tvm.ir import container as _container from . import stmt as _stmt from . import expr as _expr +from . import buffer as _buffer from . import op @@ -43,84 +45,77 @@ class BufferVar(ObjectGeneric): Do not create it directly, create use IRBuilder. 
- BufferVars support array access either via a linear index, or, if given a - shape, via a multidimensional index. + Array access through a BufferVar must use the same number of + indices as the underlying buffer was declared to have. Examples -------- In the follow example, x is BufferVar. - :code:`x[0] = ...` directly emit a store to the IRBuilder, - :code:`x[10]` translates to Load. + :code:`x[0] = ...` directly emits a BufferStore to the IRBuilder, + :code:`x[10]` translates to BufferLoad. .. code-block:: python - # The following code generate IR for x[0] = x[ + # The following code generates IR for x[0] = x[10] + 1 ib = tvm.tir.ir_builder.create() - x = ib.pointer("float32") + x = ib.allocate("float32", 20) x[0] = x[10] + 1 + # Array access using a multidimensional index y = ib.allocate("float32", (32, 32)) - # Array access using a linear index - y[(2*32) + 31] = 0. - # The same array access using a multidimensional index y[2, 31] = 0. See Also -------- IRBuilder.pointer - IRBuilder.buffer_ptr IRBuilder.allocate + """ - def __init__(self, builder, buffer_var, shape, content_type): + def __init__(self, builder, buffer, content_type): self._builder = builder - self._buffer_var = buffer_var - self._shape = shape + self._buffer = buffer self._content_type = content_type def asobject(self): - return self._buffer_var + return self._buffer @property def dtype(self): return self._content_type - def _linear_index(self, index): - if not isinstance(index, tuple) or self._shape is None: - return index - assert len(index) == len(self._shape), "Index size (%s) does not match shape size (%s)" % ( - len(index), - len(self._shape), - ) - dim_size = 1 - lidx = 0 - for dim, idx in zip(reversed(self._shape), reversed(index)): - lidx += idx * dim_size - dim_size *= dim - return lidx + def _normalize_index(self, index): + try: + index = [*index] + except TypeError: + index = [index] + + index = [x.var if isinstance(x, _expr.IterVar) else x for x in index] + + # Workaround to support previous behavior of ir_builder + # indexing by a single index, treating the buffer as if it were + # already flattened.
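+        # e.g. for a buffer declared with shape (4, 8), ``x[11]`` is
+        # rewritten below to ``x[1, 3]`` via tvm.topi.utils.unravel_index.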
+ if len(index) == 1 and len(self._buffer.shape) != 1: + index = tvm.topi.utils.unravel_index(index[0], self._buffer.shape) + + return index def __getitem__(self, index): - t = DataType(self._content_type) - index = self._linear_index(index) - if t.lanes > 1: - base = index * t.lanes - stride = 1 if (not hasattr(base, "dtype")) else const(1, base.dtype) - index = _expr.Ramp(base, stride, t.lanes) - return _expr.Load(self._content_type, self._buffer_var, index) + index = self._normalize_index(index) + return _expr.BufferLoad(self._buffer, index) def __setitem__(self, index, value): + index = self._normalize_index(index) + value = convert(value) - if value.dtype != self._content_type: + value_element = value.dtype.split("x", maxsplit=1)[0] + content_element = self._content_type.split("x", maxsplit=1)[0] + if value_element != content_element: raise ValueError( "data type does not match content type %s vs %s" % (value.dtype, self._content_type) ) - index = self._linear_index(index) - t = DataType(self._content_type) - if t.lanes > 1: - base = index * t.lanes - stride = 1 if (not hasattr(base, "dtype")) else const(1, base.dtype) - index = _expr.Ramp(base, stride, t.lanes) - self._builder.emit(_stmt.Store(self._buffer_var, value, index)) + + self._builder.emit(_stmt.BufferStore(self._buffer, value, index)) class IRBuilder(object): @@ -394,7 +389,7 @@ def let(self, var_name, value): self.emit(lambda x: _stmt.LetStmt(var, value, x)) return var - def allocate(self, dtype, shape, name="buf", scope=""): + def allocate(self, dtype, shape, name="buf", axis_separators=None, scope=""): """Create a allocate statement. Parameters @@ -408,6 +403,12 @@ def allocate(self, dtype, shape, name="buf", scope=""): name : str, optional The name of the buffer. + axis_separators : list of int, optional + + If passed, a list of separators between groups of axes, + each of which is flattened to an output axis. For flat + memory spaces, should either be None, or an empty list. + scope : str, optional The scope of the buffer. @@ -416,12 +417,18 @@ def allocate(self, dtype, shape, name="buf", scope=""): ------- buffer : BufferVar The buffer var representing the buffer. + """ - buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope)) if not isinstance(shape, (list, tuple, _container.Array)): shape = [shape] + + buffer = _buffer.decl_buffer( + shape, dtype, name, scope=scope, axis_separators=axis_separators + ) + + buffer_var = buffer.data self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x)) - return BufferVar(self, buffer_var, shape, dtype) + return BufferVar(self, buffer, dtype) def pointer(self, content_type, name="ptr", scope=""): """Create pointer variable with content type. @@ -442,10 +449,10 @@ def pointer(self, content_type, name="ptr", scope=""): ptr : BufferVar The buffer var representing the buffer. """ - buffer_var = _expr.Var(name, PointerType(PrimType(content_type), scope)) - return BufferVar(self, buffer_var, None, content_type) + buffer = _buffer.decl_buffer(shape=[1], dtype=content_type, name=name, scope=scope) + return BufferVar(self, buffer, content_type) - def buffer_ptr(self, buf, shape=None): + def buffer_ptr(self, buf): """Create pointer variable corresponds to buffer ptr. Parameters @@ -453,15 +460,12 @@ def buffer_ptr(self, buf, shape=None): buf : Buffer The buffer to be extracted. - shape : Tuple - Optional shape of the buffer. Overrides existing buffer shape. - Returns ------- ptr : BufferVar The buffer var representing the buffer. 
""" - return BufferVar(self, buf.data, buf.shape if shape is None else shape, buf.dtype) + return BufferVar(self, buf, buf.dtype) def likely(self, expr): """Add likely tag for expression. diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 74e1f70121ef..802fdc576c41 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -74,6 +74,19 @@ def InjectPrefetch(): return _ffi_api.InjectPrefetch() # type: ignore +def ApplyLayoutTransforms(): + """Reshape buffers that appear in the "layout_transform_map" + fucntion attribute. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + + """ + return _ffi_api.ApplyLayoutTransforms() # type: ignore + + def StorageFlatten(cache_line_size, create_bound_attribute: bool = False): """Flatten the multi-dimensional read/write to 1D. @@ -784,7 +797,7 @@ def ExtractPrimFuncConstants(): return _ffi_api.ExtractPrimFuncConstants() # type: ignore -def RenomalizeSplitPattern(): +def RenormalizeSplitPattern(): """Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) Returns diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py index 32f20a15016e..8bfc8032bfef 100644 --- a/python/tvm/topi/cuda/sparse.py +++ b/python/tvm/topi/cuda/sparse.py @@ -149,7 +149,6 @@ def gen_ir(data, w_data, w_indices, w_indptr, out): warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size) m = data.shape[1] nb = w_indptr.shape[0] - 1 - nnzb = w_data.shape[0] # treat csr like block size 1 bsr if len(w_data.shape) == 1: bs_n = 1 @@ -181,7 +180,7 @@ def gen_ir(data, w_data, w_indices, w_indptr, out): out_ptr = ib.buffer_ptr(out) data_ptr = ib.buffer_ptr(data) - w_data_ptr = ib.buffer_ptr(w_data, shape=(nnzb, bs_n, bs_k)) + w_data_ptr = ib.buffer_ptr(w_data) w_indices_ptr = ib.buffer_ptr(w_indices) w_indptr_ptr = ib.buffer_ptr(w_indptr) @@ -238,10 +237,11 @@ def gen_ir(data, w_data, w_indices, w_indptr, out): elem_idx = bb * rowlength_bi + tx with ib.for_range(0, bs_n, name="y", kind="unroll") as y: with ib.for_range(0, bs_k, name="z", kind="unroll") as z: - if use_warp_storage: - w_data_cache[tx, y, z] = w_data_ptr[row_start + elem_idx, y, z] - else: - w_data_cache[warp, tx, y, z] = w_data_ptr[row_start + elem_idx, y, z] + data_indices = [row_start + elem_idx] + ( + [y, z] if len(w_data.shape) > 1 else [] + ) + cache_indices = [tx, y, z] if use_warp_storage else [warp, tx, y, z] + w_data_cache[cache_indices] = w_data_ptr[data_indices] with ib.for_range(0, mi, name="i") as i: # thread local block matmul with ib.for_range(0, bs_m, name="x", kind="unroll") as x: diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index 0e39a6ce9a4b..af68ee905e56 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -311,9 +311,17 @@ def unravel_index(idx, shape): idxd = tvm.tir.indexdiv idxm = tvm.tir.indexmod indices = [] - for i in range(len(shape) - 1, -1, -1): - indices.append(idxm(idx, shape[i])) - idx = idxd(idx, shape[i]) + for i, dim in enumerate(reversed(shape)): + if dim == 0: + indices.append(0) + elif i == len(shape) - 1: + # Assuming the index is in-bounds, the last coordinate is + # already less than dim, and doesn't need the be remainder + # mod dim. 
+ indices.append(idx) + else: + indices.append(idxm(idx, dim)) + idx = idxd(idx, dim) indices = indices[::-1] return indices diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index e11bd024bb22..732045384a95 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -191,7 +191,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { // truc div TVM_TRY_REWRITE(truncdiv(x, c1) * c1 + truncmod(x, c1), x); // floor div - TVM_TRY_REWRITE(floordiv(x, c1) * c1 + floormod(x, c1), x); + TVM_TRY_REWRITE(floordiv(x, y) * y + floormod(x, y), x); + TVM_TRY_REWRITE(y * floordiv(x, y) + floormod(x, y), x); + TVM_TRY_REWRITE(floormod(x, y) + floordiv(x, y) * y, x); + TVM_TRY_REWRITE(floormod(x, y) + y * floordiv(x, y), x); + TVM_TRY_REWRITE_IF(floordiv(floormod(x, c2) + c1, c2) + floordiv(x, c2), floordiv(x + c1, c2), c2.Eval()->value > 0); diff --git a/src/autotvm/feature_visitor.cc b/src/autotvm/feature_visitor.cc index 59cac9cc9827..17a05f024621 100644 --- a/src/autotvm/feature_visitor.cc +++ b/src/autotvm/feature_visitor.cc @@ -97,14 +97,16 @@ void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) { } // memory access -void FeatureVisitor::VisitExpr_(const LoadNode* op) { - EnterMem_(op->buffer_var, op->index); +void FeatureVisitor::VisitExpr_(const BufferLoadNode* op) { + ICHECK_EQ(op->indices.size(), 1) << "FeatureVisitor can only be used on flattened buffers"; + EnterMem_(op->buffer->data, op->indices[0]); StmtExprVisitor::VisitExpr_(op); ExitMem_(); } -void FeatureVisitor::VisitStmt_(const StoreNode* op) { - EnterMem_(op->buffer_var, op->index); +void FeatureVisitor::VisitStmt_(const BufferStoreNode* op) { + ICHECK_EQ(op->indices.size(), 1) << "FeatureVisitor can only be used on flattened buffers"; + EnterMem_(op->buffer->data, op->indices[0]); StmtExprVisitor::VisitStmt_(op); ExitMem_(); } diff --git a/src/autotvm/feature_visitor.h b/src/autotvm/feature_visitor.h index 8180839b0668..3d34882c77db 100644 --- a/src/autotvm/feature_visitor.h +++ b/src/autotvm/feature_visitor.h @@ -66,8 +66,8 @@ class FeatureVisitor : public StmtExprVisitor { void VisitStmt_(const AttrStmtNode* op) final; // memory access - void VisitExpr_(const LoadNode* op) final; - void VisitStmt_(const StoreNode* op) final; + void VisitExpr_(const BufferLoadNode* op) final; + void VisitStmt_(const BufferStoreNode* op) final; using StmtExprVisitor::VisitExpr_; using StmtExprVisitor::VisitStmt_; diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 5872a49968cb..24c7ee74cdcf 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -274,6 +274,14 @@ void CodeGenHybrid::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLIN void CodeGenHybrid::VisitStmt_(const StoreNode* op) { LOG(FATAL) << "Phase 0 has no Store(s)!"; } +void CodeGenHybrid::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLINT(*) + LOG(FATAL) << "Phase 0 has no BufferLoad(s)!"; +} + +void CodeGenHybrid::VisitStmt_(const BufferStoreNode* op) { + LOG(FATAL) << "Phase 0 has no BufferStore(s)!"; +} + void CodeGenHybrid::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Phase 0 has no Let(s)!"; } diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 47c13f73022f..da45ffb6a8ce 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -89,6 +89,7 @@ class CodeGenHybrid : public ExprFunctor, // expression void 
VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const BufferLoadNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const ProducerLoadNode* op, std::ostream& os) override; // NOLINT(*) @@ -120,6 +121,7 @@ class CodeGenHybrid : public ExprFunctor, // statment void VisitStmt_(const LetStmtNode* op) override; void VisitStmt_(const StoreNode* op) override; + void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const ProducerStoreNode* op) override; void VisitStmt_(const ForNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index e229da4c26d9..16d477232eb2 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -223,6 +223,9 @@ Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) { if (!is_zero(buf->elem_offset)) { doc << ", elem_offset=" << Print(buf->elem_offset); } + if (buf->axis_separators.size()) { + doc << ", axis_separators=" << Print(buf->axis_separators); + } if (GetRef(buf).scope() != "global") { doc << ", scope=" << Doc::StrLiteral(GetRef(buf).scope()); } diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index e1ccd2f5e428..a6e506612fb6 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -265,7 +265,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintRange(const RangeNode* op); Doc PrintArray(const ArrayNode* op); Doc PrintBuffer(const BufferNode* op); - Doc PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body); + Doc PrintNonHeaderBufferDeclarations(const Array& aliasing_buffers); Doc AllocBufferDeclaration(const Buffer& buf); Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value); Doc PrintBlockVarRemaps(); @@ -912,16 +912,21 @@ Doc TVMScriptPrinter::VisitExpr_(const ReduceNode* op, ExprPrecedence* out_prece } Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) { + if (!buffer_var_usage_.count(op->var)) { + buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body); + } + Array buffer_usage = buffer_var_usage_.Get(op->var).value_or({}); + Doc doc; if (current_num_ != num_child_ - 1) { doc << "with " << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << "):"; - doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(op->var, op->body) - << PrintBody(op->body)); + doc << Doc::Indent( + 4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body)); } else { if (memo_var_.find(op->var) == memo_var_.end()) var_not_in_headers_.insert(op->var.get()); doc << Print(op->var) << ": " << Print(GetType(op->var)) << " = " << Print(op->value) << Doc::NewLine(); - doc << PrintNonHeaderBufferDeclarations(op->var, op->body) << PrintBody(op->body); + doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body); } return doc; } @@ -1008,8 +1013,59 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) { return Doc(); } +namespace { +struct AllocUsage { + Buffer alloc_buffer; + Array aliasing_buffers; +}; + +template +AllocUsage FindAllocateUsage(AllocNode* op, Map>* cache_ptr) { + Map>& cache = *cache_ptr; + if (!cache.count(op->buffer_var)) { + 
cache = BufferUsageFinder::FindUsage(std::move(cache), op->body); + } + Array buffer_usage = cache.Get(op->buffer_var).value_or({}); + + auto is_exact_match = [](Buffer a, Buffer b) { + if (a->dtype != b->dtype) return false; + if (a->shape.size() != b->shape.size()) return false; + + arith::Analyzer analyzer; + for (size_t i = 0; i < a->shape.size(); i++) { + if (!analyzer.CanProveEqual(a->shape[i], b->shape[i])) { + return false; + } + } + return true; + }; + + // If the buffer allocated via T.allocate is an exact match to the + // usage of the buffer later on, then that buffer is the return + // value of T.allocate, and no T.buffer_decl statement is needed. + Buffer alloc_buffer(op->buffer_var, op->dtype, op->extents, {}, 0, op->buffer_var->name_hint, 0, + 0, kDefault); + bool found_alloc_buf = false; + Array aliasing_buffers; + for (const auto& buf : buffer_usage) { + if (!found_alloc_buf && is_exact_match(buf, alloc_buffer)) { + alloc_buffer = buf; + found_alloc_buf = true; + } else { + aliasing_buffers.push_back(buf); + } + } + + return AllocUsage{alloc_buffer, aliasing_buffers}; +} +} // namespace + Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { - var_not_in_headers_.insert(op->buffer_var.get()); + auto usage = FindAllocateUsage(op, &buffer_var_usage_); + Buffer& alloc_buffer = usage.alloc_buffer; + Array& aliasing_buffers = usage.aliasing_buffers; + buf_not_in_headers_.insert(alloc_buffer.get()); + var_not_in_headers_.insert(alloc_buffer->data.get()); auto storage_scope = GetPtrStorageScope(op->buffer_var); Doc func_call; @@ -1027,13 +1083,12 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { Doc doc; if (current_num_ != num_child_ - 1) { - doc << "with " << func_call << " as " << Print(op->buffer_var) << ":"; - doc << Doc::Indent(4, Doc::NewLine() - << PrintNonHeaderBufferDeclarations(op->buffer_var, op->body) - << PrintBody(op->body)); + doc << "with " << func_call << " as " << Print(alloc_buffer) << ":"; + doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(aliasing_buffers) + << PrintBody(op->body)); } else { - doc << Print(op->buffer_var) << " = " << func_call << Doc::NewLine(); - doc << PrintNonHeaderBufferDeclarations(op->buffer_var, op->body) << PrintBody(op->body); + doc << Print(alloc_buffer) << " = " << func_call << Doc::NewLine(); + doc << PrintNonHeaderBufferDeclarations(aliasing_buffers) << PrintBody(op->body); } TryDeallocVar(op->buffer_var); return doc; @@ -1069,16 +1124,25 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) { } auto ndarray_str = ss.str(); + auto usage = FindAllocateUsage(alloc, &buffer_var_usage_); + Buffer& alloc_buffer = usage.alloc_buffer; + Array& aliasing_buffers = usage.aliasing_buffers; + buf_not_in_headers_.insert(alloc_buffer.get()); + var_not_in_headers_.insert(alloc_buffer->data.get()); + + Doc func_call; + func_call << tir_prefix_ << ".allocate_const(" << ndarray_str << ", " << PrintDType(alloc->dtype) + << ", " << Print(alloc->extents) << ")"; + Doc doc; var_not_in_headers_.insert(alloc->buffer_var.get()); if (current_num_ != num_child_ - 1) { - doc << "with tir.allocate_const(" << ndarray_str << ", " << PrintDType(alloc->dtype) << ", " - << Print(alloc->extents) << ")"; - doc << Doc::Indent(4, Doc::NewLine() << PrintBody(alloc->body)); + doc << "with " << func_call << " as " << Print(alloc_buffer) << ":"; + doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(aliasing_buffers) + << PrintBody(alloc->body)); } else { - doc << Print(alloc->buffer_var) << " = 
tir.allocate_const(" << ndarray_str << ", " - << PrintDType(alloc->dtype) << ", " << Print(alloc->extents); - doc << ")" << Doc::NewLine() << PrintBody(alloc->body); + doc << Print(alloc_buffer) << " = " << func_call << Doc::NewLine(); + doc << PrintNonHeaderBufferDeclarations(aliasing_buffers) << PrintBody(alloc->body); } return doc; } @@ -1465,9 +1529,30 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { if (simple_buf.count(buf)) continue; buf_not_in_headers_.insert(buf.get()); body << Print(buf) << " = " << tir_prefix_ << ".match_buffer("; + ICHECK(memo_buf_decl_.count(buf)); body << Print((*it).first) << ", " << memo_buf_decl_[buf]; body << ")" << Doc::NewLine(); } + // print preflattened buffer map + for (const auto& param : op->params) { + auto pf_buf_it = op->preflattened_buffer_map.find(param); + if (pf_buf_it != op->preflattened_buffer_map.end()) { + const Buffer& preflattened = (*pf_buf_it).second; + + auto buf_it = op->buffer_map.find(param); + ICHECK(buf_it != op->buffer_map.end()) << "Found pre-flattened buffer " << preflattened->name + << " with no corresponding post-flatten buffer."; + const Buffer& postflattened = (*buf_it).second; + + // Call Print() without assigning in order to fill memo_buf_decl_. + Print(preflattened); + buf_not_in_headers_.insert(preflattened.get()); + ICHECK(memo_buf_decl_.count(preflattened)); + + body << tir_prefix_ << ".preflattened_buffer(" << Print(postflattened) << ", " + << memo_buf_decl_.at(preflattened) << ")" << Doc::NewLine(); + } + } // print body body << "# body" << Doc::NewLine(); if (op->body->IsInstance() && @@ -1586,13 +1671,9 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) { return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer); } -Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body) { - if (!buffer_var_usage_.count(buffer_var)) { - buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), body); - } - Array buffer_usage = buffer_var_usage_.Get(buffer_var).value_or({}); +Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations(const Array& aliasing_buffers) { Doc decls; - for (const auto& buf_usage : buffer_usage) { + for (const auto& buf_usage : aliasing_buffers) { decls << Print(buf_usage) << " = " << tir_prefix_ << ".buffer_decl(" << memo_buf_decl_[buf_usage] << ")" << Doc::NewLine(); buf_not_in_headers_.insert(buf_usage.get()); diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 3d2f0fcaa2d0..0629ccd2ee19 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -423,15 +423,17 @@ class AOTExecutorCodegen : public MixedModeVisitor { */ void CopyToOutput(PrimExpr out, PrimExpr in, bool pack_input, size_t size) { // Define intermediate DLTensor to load/store the data - auto tmp0 = te::Var("tmp0", DataType::Handle()); - auto tmp1 = te::Var("tmp1", DataType::Handle()); + tir::Buffer tmp_read = + tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_read"); + tir::Buffer tmp_write = + tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_write"); te::Var loop_idx("i", DataType::Int(32)); - auto retval_i = tir::Load(DataType::UInt(8), tmp0, loop_idx, tir::const_true()); + auto retval_i = tir::BufferLoad(tmp_read, {loop_idx}); // Copy the variable from the input to the output tir::Stmt copy = tir::For(loop_idx, 0, ConstInt32(size), tir::ForKind::kSerial, - tir::Store(tmp1, tir::Let(tmp0, in, 
retval_i), loop_idx, tir::const_true())); - stmts_.push_back(tir::LetStmt(tmp1, out, copy)); + tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx})); + stmts_.push_back(tir::LetStmt(tmp_write->data, out, copy)); } /* @@ -689,7 +691,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations}); // Make the PrimFunc - return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_, + return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_, {}, DictAttrs(dict_attrs)); } diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index 530d6495adb2..46eacec13b99 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -108,7 +108,7 @@ class RelayToTIRVisitor : public MixedModeMutator { } tir::PrimFunc replacement_func(func_signature, body, VoidType(), buffer_map, - DictAttrs(dict_attrs)); + Map(), DictAttrs(dict_attrs)); ir_module_->Add(global_var, replacement_func); } diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index 6794594b5ba4..86f55caf9342 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -52,8 +52,8 @@ class ConvertAddToSubtract : public MixedModeMutator { } private: - tir::Load LoadIndex(const tir::Buffer& buffer, const PrimExpr& index) { - return tir::Load(DataType::Float(32), buffer->data, index, tir::const_true()); + tir::BufferLoad LoadIndex(const tir::Buffer& buffer, const PrimExpr& index) { + return tir::BufferLoad(buffer, {index}); } void ReplaceAddWithSubtractPrimFunc(const GlobalVar& new_global_var, const Function& func) { @@ -71,7 +71,7 @@ class ConvertAddToSubtract : public MixedModeMutator { te::Var index("index", DataType::Int(32)); tir::Sub indexed_sub = tir::Sub(LoadIndex(x_buffer, index), LoadIndex(y_buffer, index)); - tir::Stmt math_body = tir::Store(out_buffer->data, indexed_sub, index, tir::const_true()); + tir::Stmt math_body = tir::BufferStore(out_buffer, indexed_sub, {index}); tir::Stmt math_loop = tir::For(index, 0, 8, tir::ForKind::kSerial, math_body); Map buffer_map = { @@ -81,7 +81,7 @@ class ConvertAddToSubtract : public MixedModeMutator { }; tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(), - buffer_map, DictAttrs(dict_attrs)); + buffer_map, {}, DictAttrs(dict_attrs)); // Switch to TIRToRuntime hook for testing Bool tir_to_runtime = func->GetAttr("tir_to_runtime").value_or(Bool(false)); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index b5316f2b7bca..3f7da4e954a4 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1571,11 +1571,15 @@ inline te::Tensor DynamicArange(const te::Tensor& start, const te::Tensor& stop, const te::Tensor& step, tvm::DataType dtype, std::string name = "T_arange_dynamic", std::string tag = topi::kInjective) { + ICHECK_EQ(start.ndim(), 0); + ICHECK_EQ(stop.ndim(), 0); + ICHECK_EQ(step.ndim(), 0); tvm::PrimExpr num_elem = tvm::tir::Var("num_elem"); return te::compute( {num_elem}, [&](const Array& indices) { - return tvm::cast(dtype, start[0] + step[0] * indices[0]); + Array empty_indices; + return tvm::cast(dtype, start(empty_indices) + step(empty_indices) * 
indices[0]); }, name, tag); } diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 3a8391e05856..a078cabda3f6 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -253,6 +253,7 @@ class ConstantFolder : public MixedModeMutator { // Use a fresh build context in case we are already in a build context. // needed for both execution and creation(due to JIT) With fresh_build_ctx(transform::PassContext::Create()); + Map dict = (module_->attrs.defined()) ? module_->attrs->dict : Map(); Expr result = ObjectToExpr(Eval(expr, module_->type_definitions, module_->Imports(), diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 6d9d98072ee6..ded346eaaf36 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -810,11 +810,13 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& llvm::Value* arg_value = builder_->CreateInBoundsGEP( t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(begin)); - TypedPointer arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); + TypedPointer arg_tcode = + CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(begin), DataType::Int(32)); llvm::Value* ret_value = builder_->CreateInBoundsGEP( t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end)); - TypedPointer ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); + TypedPointer ret_tcode = + CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(end), DataType::Int(32)); #if TVM_LLVM_VERSION >= 90 auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall()); diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index 496c73afa4f5..32587030ba17 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -319,11 +319,13 @@ CodeGenHexagon::PackedCall CodeGenHexagon::MakeCallPackedLowered(const ArrayCreateInBoundsGEP( t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(begin)); - TypedPointer arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); + TypedPointer arg_tcode = + CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(begin), DataType::Int(32)); llvm::Value* ret_value = builder_->CreateInBoundsGEP( t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end)); - TypedPointer ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); + TypedPointer ret_tcode = + CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(end), DataType::Int(32)); #if TVM_LLVM_VERSION >= 90 auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall()); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 0545d0b4a198..cc2e495f6e37 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -437,6 +437,13 @@ llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const { if (auto* ptr = type.as()) { return DTypeToLLVMType(ptr->dtype); } else if (auto* ptr = type.as()) { + // LLVM IR doesn't allow void*, so we need to recognize this + // pattern explicitly. + if (auto* primtype = ptr->element_type.as()) { + if (primtype->dtype.is_void()) { + return t_void_p_; + } + } // TODO(tvm-team) consider put storage scope into the pointer type. 
return GetLLVMType(ptr->element_type)->getPointerTo(GetGlobalAddressSpace()); } else if (IsVoidType(type)) { @@ -781,17 +788,35 @@ llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) { return ptr; } -CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(DataType t, llvm::Value* buffer, - llvm::Value* index) { - llvm::PointerType* btype = llvm::dyn_cast(buffer->getType()); - ICHECK(btype != nullptr); - llvm::Type* llvm_type = DTypeToLLVMType(t); - llvm::PointerType* ttype = llvm_type->getPointerTo(btype->getAddressSpace()); - if (btype != ttype) { - buffer = builder_->CreatePointerCast(buffer, ttype); +CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(llvm::Value* buffer_ptr, + DataType buffer_element_dtype, + llvm::Value* index, DataType value_dtype) { + llvm::PointerType* buffer_ptr_type = llvm::dyn_cast(buffer_ptr->getType()); + ICHECK(buffer_ptr_type != nullptr); + auto address_space = buffer_ptr_type->getAddressSpace(); + + llvm::Type* element_type = DTypeToLLVMType(buffer_element_dtype); + llvm::PointerType* element_ptr_type = + DTypeToLLVMType(buffer_element_dtype)->getPointerTo(address_space); + llvm::Type* value_type = DTypeToLLVMType(value_dtype); + llvm::PointerType* value_ptr_type = value_type->getPointerTo(address_space); + + ICHECK(index->getType()->isIntegerTy()) << "Expected buffer index to be an integer"; + + if (buffer_ptr_type != element_ptr_type) { + buffer_ptr = builder_->CreatePointerCast(buffer_ptr, element_ptr_type); } - llvm::Value* ptr = builder_->CreateInBoundsGEP(llvm_type, buffer, index); - return TypedPointer(llvm_type, ptr); + ICHECK(!HasAlignmentPadding(buffer_element_dtype)) + << "DType " << buffer_element_dtype + << " has padding for alignment. TVM data arrays are expected to be densely packed, with no " + "padding for alignment."; + llvm::Value* value_ptr = builder_->CreateInBoundsGEP(element_type, buffer_ptr, index); + + if (element_ptr_type != value_ptr_type) { + value_ptr = builder_->CreatePointerCast(value_ptr, value_ptr_type); + } + + return TypedPointer(value_type, value_ptr); } llvm::Value* CodeGenLLVM::GetVarValue(const VarNode* v) const { @@ -976,15 +1001,15 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::tvm_storage_sync())) { return CreateStorageSync(op); } else if (op->op.same_as(builtin::address_of())) { - const LoadNode* l = op->args[0].as(); - ICHECK(op->args.size() == 1 && l); - TypedPointer buffer_ptr; - if (const RampNode* r = l->index.as()) { - PrimExpr index = r->base / make_const(DataType::Int(32), r->lanes); - buffer_ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(index)); - } else { - buffer_ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(l->index)); + const BufferLoadNode* load = op->args[0].as(); + ICHECK(op->args.size() == 1 && load); + ICHECK_EQ(load->indices.size(), 1) << "LLVM only supports flat memory allocations."; + PrimExpr index = load->indices[0]; + if (const RampNode* r = index.as()) { + index = r->base; } + TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(load->buffer->data), load->buffer->dtype, + MakeValue(index), load->dtype); unsigned addrspace = llvm::dyn_cast(buffer_ptr.addr->getType())->getAddressSpace(); return builder_->CreatePointerCast(buffer_ptr.addr, t_char_->getPointerTo(addrspace)); @@ -1236,15 +1261,40 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { + LOG(FATAL) << "Unexpected deprecated LoadNode. 
Use BufferLoadNode instead."; + return NULL; +} + +bool CodeGenLLVM::HasAlignmentPadding(DataType dtype) { + const llvm::DataLayout& data_layout = module_->getDataLayout(); + int bytes = data_layout.getTypeAllocSize(DTypeToLLVMType(dtype)); + int bytes_scalar = data_layout.getTypeAllocSize(DTypeToLLVMType(dtype.element_of())); + return bytes != bytes_scalar * dtype.lanes(); +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { + ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers."; + DataType t = op->dtype; - bool is_volatile = volatile_buf_.count(op->buffer_var.get()); - llvm::Value* buffer = MakeValue(op->buffer_var); - llvm::Value* index = MakeValue(op->index); + DataType buffer_element_dtype = op->buffer->dtype; + Var buffer_var = op->buffer->data; + PrimExpr buffer_index = op->indices[0]; - if (t.lanes() == 1) { + bool is_volatile = volatile_buf_.count(buffer_var.get()); + + if (t.lanes() == buffer_element_dtype.lanes()) { int alignment, native_bits; - GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits); - TypedPointer buffer_ptr = CreateBufferPtr(t, buffer, index); + GetAlignment(t, buffer_var.get(), buffer_index, &alignment, &native_bits); + + TypedPointer buffer_ptr; + if (HasAlignmentPadding(buffer_element_dtype)) { + buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype.element_of(), + MakeValue(buffer_element_dtype.lanes() * buffer_index), t); + } else { + buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype, + MakeValue(buffer_index), t); + } + #if TVM_LLVM_VERSION >= 110 llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment), is_volatile); @@ -1254,22 +1304,18 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { #else llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile); #endif - AddAliasInfo(load, op->buffer_var.get(), op->index); + AddAliasInfo(load, buffer_var.get(), buffer_index); return load; } else { // vector load - if (const RampNode* ramp = op->index.as()) { + if (const RampNode* ramp = buffer_index.as()) { if (is_one(ramp->stride)) { int alignment, native_bits; - GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits); - ICHECK_EQ(ramp->lanes, t.lanes()); + GetAlignment(t, buffer_var.get(), ramp->base, &alignment, &native_bits); + ICHECK_EQ(ramp->lanes * buffer_element_dtype.lanes(), t.lanes()); // The index argument is element-based, to create buffer pointer for t's element type. 
- TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); - unsigned addrspace = - llvm::dyn_cast(buffer->getType())->getAddressSpace(); - buffer_ptr.type = DTypeToLLVMType(t); - buffer_ptr.addr = - builder_->CreatePointerCast(buffer_ptr.addr, buffer_ptr.type->getPointerTo(addrspace)); + TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), op->buffer->dtype, + MakeValue(ramp->base), t); #if TVM_LLVM_VERSION >= 110 llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment), is_volatile); @@ -1279,7 +1325,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { #else llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile); #endif - AddAliasInfo(load, op->buffer_var.get(), op->index); + AddAliasInfo(load, buffer_var.get(), buffer_index); return load; } } @@ -1288,7 +1334,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { int basic_align = t.bits() / 8; llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(t)); auto f = [&](int i, llvm::Value* index) { - TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, index); + TypedPointer buffer_ptr = + CreateBufferPtr(MakeValue(op->buffer->data), op->buffer->dtype, index, t.element_of()); #if TVM_LLVM_VERSION >= 110 llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(basic_align), is_volatile); @@ -1299,9 +1346,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, basic_align, is_volatile); #endif ret = builder_->CreateInsertElement(ret, load, ConstInt32(i)); - AddAliasInfo(load, op->buffer_var.get(), PrimExpr()); + AddAliasInfo(load, buffer_var.get(), PrimExpr()); }; - this->Scalarize(op->index, f); + this->Scalarize(buffer_index, f); return ret; } @@ -1366,17 +1413,34 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) { } void CodeGenLLVM::VisitStmt_(const StoreNode* op) { - ICHECK(is_one(op->predicate)) << op->predicate; - DataType t = op->value.dtype(); - bool is_volatile = volatile_buf_.count(op->buffer_var.get()); - llvm::Value* buffer = MakeValue(op->buffer_var); - llvm::Value* index = MakeValue(op->index); + LOG(FATAL) << "Unexpected deprecated StoreNode. 
Use BufferStoreNode instead."; +} + +void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { + ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers."; + + DataType value_dtype = op->value.dtype(); + DataType buffer_element_dtype = op->buffer->dtype; + Var buffer_var = op->buffer->data; + PrimExpr buffer_index = op->indices[0]; + + bool is_volatile = volatile_buf_.count(buffer_var.get()); + llvm::Value* buffer = MakeValue(buffer_var); llvm::Value* value = MakeValue(op->value); - if (t.lanes() == 1) { + if (value_dtype.lanes() == buffer_element_dtype.lanes()) { int alignment, native_bits; - GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits); - TypedPointer buffer_ptr = CreateBufferPtr(t, buffer, index); + GetAlignment(value_dtype, buffer_var.get(), buffer_index, &alignment, &native_bits); + + TypedPointer buffer_ptr; + if (HasAlignmentPadding(buffer_element_dtype)) { + buffer_ptr = + CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype.element_of(), + MakeValue(buffer_element_dtype.lanes() * buffer_index), value_dtype); + } else { + buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype, + MakeValue(buffer_index), value_dtype); + } #if TVM_LLVM_VERSION >= 110 llvm::StoreInst* store = builder_->CreateAlignedStore(value, buffer_ptr.addr, llvm::Align(alignment), is_volatile); @@ -1384,20 +1448,21 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { llvm::StoreInst* store = builder_->CreateAlignedStore(value, buffer_ptr.addr, alignment, is_volatile); #endif - AddAliasInfo(store, op->buffer_var.get(), op->index); + AddAliasInfo(store, buffer_var.get(), buffer_index); return; } else { // vector store - if (const RampNode* ramp = op->index.as()) { + if (const RampNode* ramp = buffer_index.as()) { if (is_one(ramp->stride)) { int alignment, native_bits; - GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits); - ICHECK_EQ(ramp->lanes, t.lanes()); + GetAlignment(value_dtype, buffer_var.get(), ramp->base, &alignment, &native_bits); + ICHECK_EQ(ramp->lanes * buffer_element_dtype.lanes(), value_dtype.lanes()); // The index argument is element-based, to create buffer pointer for t's element type. - TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); + TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype, + MakeValue(ramp->base), value_dtype); unsigned addrspace = llvm::dyn_cast(buffer->getType())->getAddressSpace(); - buffer_ptr.type = DTypeToLLVMType(t); + buffer_ptr.type = DTypeToLLVMType(value_dtype); buffer_ptr.addr = builder_->CreatePointerCast(buffer_ptr.addr, buffer_ptr.type->getPointerTo(addrspace)); #if TVM_LLVM_VERSION >= 110 @@ -1407,16 +1472,17 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { llvm::StoreInst* store = builder_->CreateAlignedStore(value, buffer_ptr.addr, alignment, is_volatile); #endif - AddAliasInfo(store, op->buffer_var.get(), op->index); + AddAliasInfo(store, buffer_var.get(), buffer_index); return; } } } - ICHECK_GE(t.bits(), 8); + ICHECK_GE(value_dtype.bits(), 8); // scalarized store. 
- int basic_align = t.bits() / 8; + int basic_align = value_dtype.bits() / 8; auto f = [&](int i, llvm::Value* index) { - TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, index); + TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype, + index, value_dtype.element_of()); #if TVM_LLVM_VERSION >= 110 llvm::StoreInst* store = builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i), buffer_ptr.addr, @@ -1425,9 +1491,9 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { llvm::StoreInst* store = builder_->CreateAlignedStore( builder_->CreateExtractElement(value, i), buffer_ptr.addr, basic_align, is_volatile); #endif - AddAliasInfo(store, op->buffer_var.get(), PrimExpr()); + AddAliasInfo(store, buffer_var.get(), PrimExpr()); }; - this->Scalarize(op->index, f); + this->Scalarize(buffer_index, f); } void CodeGenLLVM::VisitStmt_(const ForNode* op) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 5431e92e0a10..e8cbe7ae445f 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -171,12 +171,14 @@ class CodeGenLLVM : public ExprFunctor, llvm::Value* VisitExpr_(const SelectNode* op) override; llvm::Value* VisitExpr_(const LetNode* op) override; llvm::Value* VisitExpr_(const LoadNode* op) override; + llvm::Value* VisitExpr_(const BufferLoadNode* op) override; llvm::Value* VisitExpr_(const CallNode* op) override; llvm::Value* VisitExpr_(const RampNode* op) override; llvm::Value* VisitExpr_(const ShuffleNode* op) override; llvm::Value* VisitExpr_(const BroadcastNode* op) override; // stmt void VisitStmt_(const StoreNode* op) override; + void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const ForNode* op) override; void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; @@ -319,6 +321,8 @@ class CodeGenLLVM : public ExprFunctor, // Get alignment given index. void GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index, int* p_alignment, int* p_native_bits); + // Returns whether the LLVM type has padding for alignment + bool HasAlignmentPadding(DataType dtype); // Get constant string llvm::Constant* GetConstString(const std::string& str); // do a scalarize call with f @@ -338,7 +342,8 @@ class CodeGenLLVM : public ExprFunctor, llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateBroadcast(llvm::Value* value, int lanes); - TypedPointer CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index); + TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype, + llvm::Value* index, DataType value_dtype); // Vector concatenation. llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent); llvm::Value* CreateVecFlip(llvm::Value* vec); diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 01c1c911b7de..1752c2a2e826 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -159,78 +159,58 @@ void CodeGenC::PrintSSAAssign(const std::string& target, const std::string& src, } // Print a reference expression to a buffer. 
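// For example (illustrative outputs of the rewritten GetBufferRef below,
// given a float32 buffer with C identifier `A` and index expression `i`):
//   requested type equals the buffer dtype:  A[i]
//   requested float32x4 view of the buffer:  *(float4*)(A + i)
//   scalar int4 (backed by int32 storage):   *((int*)A + i / 8)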
-std::string CodeGenC::GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index) { +std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) { + const VarNode* buffer_var = buffer->data.get(); std::ostringstream os; - std::string vid = GetVarID(buffer); + std::string vid = GetVarID(buffer_var); std::string scope; - if (alloc_storage_scope_.count(buffer)) { - scope = alloc_storage_scope_.at(buffer); + if (alloc_storage_scope_.count(buffer_var)) { + scope = alloc_storage_scope_.at(buffer_var); } - bool is_vol = IsVolatile(buffer); - if (t.lanes() == 1) { - if (!HandleTypeMatch(buffer, t) || is_vol) { - os << "(("; - if (is_vol) { - os << "volatile "; - } - // Scope may not be part of type. - if (!scope.empty() && IsScopePartOfType()) { - PrintStorageScope(scope, os); - } - PrintType(t, os); - os << "*)" << vid << ')'; - } else { - os << vid; - } - os << "[("; - PrintExpr(index, os); - os << ")"; - if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) { - os << " / " << (32 / t.bits()); - } - os << ']'; - } else { - // Buffer declared as vector type. - // optimize for case where it is in register, - if (HandleTypeMatch(buffer, t) && !is_vol) { - // optimize for constant access - if (auto* ptr = index.as()) { - int64_t offset = ptr->value; - ICHECK_EQ(offset % t.lanes(), 0) << "Find unaligned vector load to a vector type"; - os << vid << '[' << (offset / t.lanes()) << ']'; - return os.str(); - } - } - os << "(("; + bool is_vol = IsVolatile(buffer_var); + + auto ptr_cast = [this, is_vol, scope](DataType pointed_to) { + std::ostringstream ptr_os; + ptr_os << "("; if (is_vol) { - os << "volatile "; + ptr_os << "volatile "; } if (!scope.empty() && IsScopePartOfType()) { - PrintStorageScope(scope, os); - } - PrintType(t, os); - os << "*)("; - if (!HandleTypeMatch(buffer, t.element_of())) { - os << '('; - if (!scope.empty() && IsScopePartOfType()) { - PrintStorageScope(scope, os); - } - PrintType(t.element_of(), os); - os << "*)"; - } - if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) { - os << vid << ") + ("; - PrintExpr(index, os); - os << ")"; - os << " / " << t.lanes(); - os << ")[0]"; - } else { - os << vid << " + ("; - PrintExpr(index, os); - os << ")"; - os << "))[0]"; + PrintStorageScope(scope, ptr_os); } + PrintType(pointed_to, ptr_os); + ptr_os << "*)"; + return ptr_os.str(); + }; + + DataType buffer_element_dtype = buffer->dtype; + + std::string buffer_str = vid; + if (!HandleTypeMatch(buffer_var, buffer_element_dtype) || is_vol) { + std::stringstream temp; + temp << "(" << ptr_cast(buffer_element_dtype) << vid << ")"; + buffer_str = temp.str(); + } + + std::string index_str = PrintExpr(index); + if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) { + // This is a special case, because CodegenCUDA::PrintType() + // returns "int" for bool and for 4-bit integers. In most cases, + // we divide by the number of lanes to determine the index. + // However, the backing type for scalar int4 and scalar bool is + // int32. Therefore, we need to divide by the ratio of their + // sizes in that case. + int div_factor = (t.lanes() == 1) ? 
(32 / t.bits()) : t.lanes(); + + os << "*(" << "(" << ptr_cast(t) << vid << ")" << " + " << index_str << " / " << div_factor << ")"; + } else if (t == buffer_element_dtype) { + os << buffer_str << "[" << index_str << "]"; + } else { + os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")"; } + return os.str(); } @@ -334,11 +314,11 @@ void CodeGenC::PrintVecElemStore(const std::string& vec, DataType t, int i, stream << vec << ".s" << std::hex << i << " = " << value << ";\n" << std::dec; } -std::string CodeGenC::GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) { +std::string CodeGenC::GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) { return GetBufferRef(t, buffer, base); } -void CodeGenC::PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, +void CodeGenC::PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base, const std::string& value) { std::string ref = GetBufferRef(t, buffer, base); this->PrintIndent(); @@ -586,17 +566,10 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) PrintExpr(op->args[2], os); os << ")"; } else if (op->op.same_as(builtin::address_of())) { - const LoadNode* l = op->args[0].as(); - ICHECK(op->args.size() == 1 && l); - os << "(("; - this->PrintType(l->dtype.element_of(), os); - os << " *)" << this->GetVarID(l->buffer_var.get()) << " + " - << "("; - this->PrintExpr(l->index, os); - if (l->dtype.bits() == 4 || (l->dtype.bits() == 1 && l->dtype.is_int())) { - os << " / " << (32 / l->dtype.bits()); - } - os << "))"; + const BufferLoadNode* load = op->args[0].as(); + ICHECK(op->args.size() == 1 && load); + ICHECK_EQ(load->indices.size(), 1) << "CodeGenC only supports flat memory allocations."; + os << "(&(" << GetBufferRef(load->dtype, load->buffer.get(), load->indices[0]) << "))"; } else if (op->op.same_as(builtin::tvm_struct_get())) { ICHECK_EQ(op->args.size(), 3U); os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as()->value); @@ -681,18 +654,27 @@ void CodeGenC::VisitStmt_(const AllocateConstNode* op) { } void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) + LOG(FATAL) << "Unexpected deprecated LoadNode. Use BufferLoadNode instead."; +} + +void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLINT(*) + ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported."; + + DataType value_dtype = op->dtype; + PrimExpr index = op->indices[0]; + Var buffer_var = op->buffer->data; + DataType element_dtype = op->buffer->dtype; + int lanes = op->dtype.lanes(); // declare type.
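// For example (illustrative): a float32 load from a float32 buffer and a
// float32x4 load from a float32x4 buffer both satisfy the lane-count check
// below and go straight through GetBufferRef; a float32x4 load from a
// scalar float32 buffer instead takes the vector-load path, or is
// scalarized when its index is not a unit-stride ramp.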
- if (op->dtype.lanes() == 1) { - std::string ref = GetBufferRef(op->dtype, op->buffer_var.get(), op->index); + if (value_dtype.lanes() == element_dtype.lanes()) { + std::string ref = GetBufferRef(op->dtype, op->buffer.get(), index); HandleVolatileLoads(ref, op, os); } else { - ICHECK(is_one(op->predicate)) << "predicated load is not supported"; - bool can_vector_load = false; arith::PVar base; - if (arith::ramp(base, 1, op->dtype.lanes()).Match(op->index)) { - const RampNode* ramp = op->index.as(); + if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) { + const RampNode* ramp = index.as(); ICHECK(ramp); arith::ModularSet me = arith::Analyzer().modular_set(ramp->base); // The condition: {k * coeff + base} divisible by the alignment for any k @@ -702,19 +684,19 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) } if (can_vector_load) { - std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base.Eval()); + std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval()); HandleVolatileLoads(ref, op, os); } else { std::ostringstream svalue_expr; - std::string sindex = SSAGetID(PrintExpr(op->index), op->index.dtype()); - std::string vid = GetVarID(op->buffer_var.get()); + std::string sindex = SSAGetID(PrintExpr(index), index.dtype()); + std::string vid = GetVarID(buffer_var.get()); DataType elem_type = op->dtype.element_of(); for (int i = 0; i < lanes; ++i) { std::ostringstream value_temp; - if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) { + if (!HandleTypeMatch(buffer_var.get(), elem_type)) { value_temp << "(("; - if (op->buffer_var.get()->dtype.is_handle()) { - auto it = alloc_storage_scope_.find(op->buffer_var.get()); + if (buffer_var.get()->dtype.is_handle()) { + auto it = alloc_storage_scope_.find(buffer_var.get()); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, value_temp); } @@ -725,7 +707,7 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) value_temp << vid; } value_temp << '['; - PrintVecElemLoad(sindex, op->index.dtype(), i, value_temp); + PrintVecElemLoad(sindex, index.dtype(), i, value_temp); value_temp << ']'; PrintVecElemLoadExpr(op->dtype, i, value_temp.str(), svalue_expr); } @@ -735,35 +717,44 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) } void CodeGenC::VisitStmt_(const StoreNode* op) { - DataType t = op->value.dtype(); - if (t.lanes() == 1) { + LOG(FATAL) << "Unexpected deprecated StoreNode. 
Use BufferStoreNode instead."; +} + +void CodeGenC::VisitStmt_(const BufferStoreNode* op) { + ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported."; + + DataType value_dtype = op->value.dtype(); + DataType element_dtype = op->buffer->dtype; + PrimExpr index_expr = op->indices[0]; + Var buffer_var = op->buffer->data; + + if (value_dtype.lanes() == element_dtype.lanes()) { std::string value = this->PrintExpr(op->value); - std::string ref = this->GetBufferRef(t, op->buffer_var.get(), op->index); + std::string ref = this->GetBufferRef(value_dtype, op->buffer.get(), index_expr); this->PrintIndent(); stream << ref << " = " << value << ";\n"; } else { - ICHECK(is_one(op->predicate)) << "Predicated store is not supported"; arith::PVar base; - if (arith::ramp(base, 1, t.lanes()).Match(op->index)) { + if (arith::ramp(base, 1, value_dtype.lanes()).Match(index_expr)) { std::string value = this->PrintExpr(op->value); - this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value); + this->PrintVecStore(op->buffer.get(), value_dtype, base.Eval(), value); } else { // The assignment below introduces side-effect, and the resulting value cannot // be reused across multiple expression, thus a new scope is needed int vec_scope = BeginScope(); // store elements seperately - std::string index = SSAGetID(PrintExpr(op->index), op->index.dtype()); + std::string index = SSAGetID(PrintExpr(index_expr), index_expr.dtype()); std::string value = SSAGetID(PrintExpr(op->value), op->value.dtype()); - std::string vid = GetVarID(op->buffer_var.get()); - for (int i = 0; i < t.lanes(); ++i) { + std::string vid = GetVarID(buffer_var.get()); + for (int i = 0; i < value_dtype.lanes(); ++i) { this->PrintIndent(); - DataType elem_type = t.element_of(); - if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) { + DataType elem_type = value_dtype.element_of(); + if (!HandleTypeMatch(buffer_var.get(), elem_type)) { stream << "(("; - if (op->buffer_var.get()->dtype.is_handle()) { - auto it = alloc_storage_scope_.find(op->buffer_var.get()); + if (buffer_var.get()->dtype.is_handle()) { + auto it = alloc_storage_scope_.find(buffer_var.get()); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, stream); } @@ -774,7 +765,7 @@ void CodeGenC::VisitStmt_(const StoreNode* op) { stream << vid; } stream << '['; - PrintVecElemLoad(index, op->index.dtype(), i, stream); + PrintVecElemLoad(index, index_expr.dtype(), i, stream); stream << "] = "; PrintVecElemLoad(value, op->value.dtype(), i, stream); stream << ";\n"; diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 2af77bb28b53..4f671950260e 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -126,6 +126,7 @@ class CodeGenC : public ExprFunctor, // expression void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const BufferLoadNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*) @@ -155,6 +156,7 @@ class CodeGenC : public ExprFunctor, // statment void VisitStmt_(const LetStmtNode* op) override; void VisitStmt_(const StoreNode* op) override; + void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const ForNode* op) override; void 
VisitStmt_(const WhileNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; @@ -176,9 +178,9 @@ class CodeGenC : public ExprFunctor, virtual void PrintVecBinaryOp(const std::string& op, DataType op_type, PrimExpr lhs, PrimExpr rhs, std::ostream& os); // NOLINT(*) // print vector load - virtual std::string GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base); + virtual std::string GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base); // print vector store - virtual void PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, + virtual void PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base, const std::string& value); // NOLINT(*) // print load of single element virtual void PrintVecElemLoad(const std::string& vec, DataType t, int i, @@ -201,7 +203,7 @@ class CodeGenC : public ExprFunctor, // Print reference to struct location std::string GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind); // Print reference to a buffer as type t in index. - virtual std::string GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index); + virtual std::string GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index); /*! * \brief Handle volatile loads. @@ -211,7 +213,8 @@ class CodeGenC : public ExprFunctor, * does not implement volatile member functions. CUDA codegen will cast * away volatile qualifier from CUDA __half types. */ - virtual void HandleVolatileLoads(const std::string& value, const LoadNode* op, std::ostream& os) { + virtual void HandleVolatileLoads(const std::string& value, const BufferLoadNode* op, + std::ostream& os) { // By default, do nothing but print the loaded value. os << value; } diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 7ddea46c07bd..db23c0152865 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -97,6 +97,10 @@ void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << "void*"; return; } + if (t.is_void()) { + os << "void"; + return; + } if (t == DataType::Bool()) { os << "bool"; return; diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 984f8a13351e..0dda079066d9 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -171,6 +171,12 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << "void*"; return; } + + if (t.is_void()) { + os << "void"; + return; + } + bool fail = false; if (t.is_float()) { switch (t.bits()) { @@ -1115,12 +1121,12 @@ int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode return 0; } -void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const LoadNode* op, +void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const BufferLoadNode* op, std::ostream& os) { // Cast away volatile qualifier for fp16 types. That is, only loads and // stores are volatile. The loaded objects are not marked as volatile. 
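// For example (illustrative): if the reference expression for a volatile
// half2 load prints as ((volatile half2*)S)[i], the override below emits
// (half2)(((volatile half2*)S)[i]), copying the value out of the volatile
// lvalue before it is used.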
// - if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer_var.get())) { + if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer->data.get())) { os << "("; PrintType(op->dtype, os); os << ")(" << value << ")"; diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index 385b7343c8fd..673753c470ae 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -76,7 +76,8 @@ class CodeGenCUDA final : public CodeGenC { private: // Handle volatile loads - void HandleVolatileLoads(const std::string& value, const LoadNode* op, std::ostream& os) final; + void HandleVolatileLoads(const std::string& value, const BufferLoadNode* op, + std::ostream& os) final; // Whether scope such as "__shared__" or "__constant__" is part of type. bool IsScopePartOfType() const final { return false; } diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index b44afec57d5d..a76da36ea725 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -177,6 +177,11 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << "void*"; return; } + + if (t.is_void()) { + os << "void"; + return; + } if (t == DataType::Bool()) { os << "bool"; return; diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index a9cd9d8ae930..a0e19ca35cd9 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -174,6 +174,10 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << "void*"; return; } + if (t.is_void()) { + os << "void"; + return; + } if (t == DataType::Bool()) { os << "bool"; return; @@ -256,21 +260,22 @@ void CodeGenOpenCL::PrintType(const Type& type, std::ostream& os) { // NOLINT(* } } -void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t, PrimExpr base, +void CodeGenOpenCL::PrintVecAddr(const BufferNode* buffer, DataType t, PrimExpr base, std::ostream& os) { // NOLINT(*) - if (!HandleTypeMatch(buffer, t.element_of())) { + const VarNode* buffer_var = buffer->data.get(); + if (!HandleTypeMatch(buffer_var, t.element_of())) { os << '('; - auto it = alloc_storage_scope_.find(buffer); + auto it = alloc_storage_scope_.find(buffer_var); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, os); } PrintType(t.element_of(), os); os << "*)"; } - os << GetVarID(buffer) << " + "; + os << GetVarID(buffer_var) << " + "; PrintExpr(base, os); } -std::string CodeGenOpenCL::GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) { +std::string CodeGenOpenCL::GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) { std::ostringstream os; os << "vload" << t.lanes() << "(0, "; PrintVecAddr(buffer, t, base, os); @@ -278,7 +283,7 @@ std::string CodeGenOpenCL::GetVecLoad(DataType t, const VarNode* buffer, PrimExp return os.str(); } -void CodeGenOpenCL::PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, +void CodeGenOpenCL::PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base, const std::string& value) { this->PrintIndent(); stream << "vstore" << t.lanes() << "(" << value << ", 0, "; @@ -337,13 +342,17 @@ std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType } void CodeGenOpenCL::VisitStmt_(const StoreNode* op) { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. 
Please use BufferStoreNode instead."; +} + +void CodeGenOpenCL::VisitStmt_(const BufferStoreNode* op) { if (auto call = op->value.as()) { if (call->op.same_as(builtin::texture2d_load())) { need_texture_ssa_ = false; // If storing a texture load into a buffer, don't use an // intermediate local unless the buffer allocation is a // single element selected from the texture read. - auto it = allocation_size_.find(op->buffer_var.get()); + auto it = allocation_size_.find(op->buffer->data.get()); if (it != allocation_size_.end() && it->second == 1) { need_texture_ssa_ = true; } @@ -371,16 +380,17 @@ void CodeGenOpenCL::VisitStmt_(const AllocateNode* op) { void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { if (op->op.same_as(builtin::address_of())) { // Overload tvm_address_of to add storage scope (e.g. __global). - const LoadNode* load = op->args[0].as(); + const BufferLoadNode* load = op->args[0].as(); ICHECK(op->args.size() == 1 && load); + ICHECK_EQ(load->indices.size(), 0) << "CodeGenOpenCL only supports flat memory allocations."; os << "(("; - auto it = alloc_storage_scope_.find(load->buffer_var.get()); + auto it = alloc_storage_scope_.find(load->buffer->data.get()); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, os); } this->PrintType(load->dtype.element_of(), os); - os << " *)" << this->GetVarID(load->buffer_var.get()) << " + "; - this->PrintExpr(load->index, os); + os << " *)" << this->GetVarID(load->buffer->data.get()) << " + "; + this->PrintExpr(load->indices[0], os); os << ')'; } else if (op->op.same_as(builtin::texture2d_store())) { auto* ptr_type = op->args[0].as()->type_annotation.as(); diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index 2670c601c43c..3508eef43185 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -47,11 +47,11 @@ class CodeGenOpenCL final : public CodeGenC { void PrintStorageSync(const CallNode* op) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void PrintType(const Type& type, std::ostream& os) final; // NOLINT(*) - std::string GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) final; - void PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, + std::string GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) final; + void PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base, const std::string& value) final; // NOLINT(*) // the address of load/store - void PrintVecAddr(const VarNode* buffer, DataType t, PrimExpr base, + void PrintVecAddr(const BufferNode* buffer, DataType t, PrimExpr base, std::ostream& os); // NOLINT(*) void PrintRestrict(const Var& v, std::ostream& os) final; // NOLINT(*) std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*) @@ -64,6 +64,7 @@ class CodeGenOpenCL final : public CodeGenC { void VisitExpr_(const CastNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) void VisitStmt_(const StoreNode* op) final; // NOLINT(*) + void VisitStmt_(const BufferStoreNode* op) final; // NOLINT(*) // overload min and max to avoid ambiguous call errors void VisitExpr_(const MinNode* op, std::ostream& os) final; diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc index 5dcf1587bdb9..5acb42071b62 100644 --- a/src/target/source/codegen_source_base.cc +++ b/src/target/source/codegen_source_base.cc @@ -119,6 
+119,10 @@ void CodeGenSourceBase::PrintType(DataType type, std::ostream& os) { // NOLINT( os << "void*"; return; } + if (type.is_void()) { + os << "void"; + return; + } if (type.is_float()) { if (type.bits() == 32) { os << "float"; diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 1d30b9bfd63a..0427d8cd5853 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -412,22 +412,23 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) { return builder_->Concat(values); } -spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { - ICHECK(is_one(op->predicate)); +spirv::Value CodeGenSPIRV::VisitExpr_(const BufferLoadNode* op) { + ICHECK_EQ(op->indices.size(), 1) << "SPIR-V codegen expects flat memory buffers"; + Var buffer_var = op->buffer->data; + PrimExpr prim_index = op->indices[0]; DataType desired_read_type = op->dtype; if (desired_read_type == DataType::Bool()) { desired_read_type = boolean_storage_type_.with_lanes(desired_read_type.lanes()); } - const VarNode* buffer_var = op->buffer_var.get(); - auto it = storage_info_.find(buffer_var); + auto it = storage_info_.find(buffer_var.get()); ICHECK(it != storage_info_.end()); StorageInfo& info = it->second; - info.CheckContentType(desired_read_type, op->index.dtype().lanes()); + info.CheckContentType(desired_read_type, prim_index.dtype().lanes()); spirv::SType content_type = builder_->GetSType(info.element_type); - spirv::Value buffer = MakeValue(op->buffer_var); + spirv::Value buffer = MakeValue(buffer_var); spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class); uint32_t mask = spv::MemoryAccessMaskNone; @@ -438,7 +439,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { if (desired_read_type == info.element_type) { // Requested a single value from an array. This may be a scalar load // or a vectorized load, based on the array element type. - spirv::Value index = MakeValue(op->index); + spirv::Value index = MakeValue(prim_index); spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); spirv::Value loaded = builder_->MakeValue(spv::OpLoad, content_type, ptr, mask); // OpTypeBool have no physical address/storage. 
Here, cast from @@ -457,13 +458,13 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); values.emplace_back(builder_->MakeValue(spv::OpLoad, content_type, ptr, mask)); }; - this->Scalarize(op->index, f); + this->Scalarize(prim_index, f); return builder_->Concat(values); } else { LOG(FATAL) << "Cannot perform buffer access of buffer variable '" << buffer_var->name_hint << "' with element type " << info.element_type << " using index of type " - << op->index->dtype << " to produce output of type " << op->dtype; + << prim_index->dtype << " to produce output of type " << op->dtype; return spirv::Value(); } } @@ -483,15 +484,18 @@ void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::functionpredicate)); - auto it = storage_info_.find(op->buffer_var.get()); +void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) { + ICHECK_EQ(op->indices.size(), 1) << "SPIR-V codegen expects flat memory buffers"; + Var buffer_var = op->buffer->data; + PrimExpr prim_index = op->indices[0]; + + auto it = storage_info_.find(buffer_var.get()); ICHECK(it != storage_info_.end()); StorageInfo& info = it->second; - info.CheckContentType(op->value.dtype(), op->index.dtype().lanes()); + info.CheckContentType(op->value.dtype(), prim_index.dtype().lanes()); spirv::SType content_type = builder_->GetSType(info.element_type); - spirv::Value buffer = MakeValue(op->buffer_var); + spirv::Value buffer = MakeValue(buffer_var); spirv::Value value = MakeValue(op->value); spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class); @@ -505,7 +509,7 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { // or a vectorized store, based on the array element type. ICHECK_EQ(info.element_type, op->value.dtype()) << "Vulkan only allow one type access to the same buffer"; - spirv::Value index = MakeValue(op->index); + spirv::Value index = MakeValue(prim_index); spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); builder_->MakeInst(spv::OpStore, ptr, value, mask); @@ -517,12 +521,12 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); builder_->MakeInst(spv::OpStore, ptr, elem, mask); }; - this->Scalarize(op->index, f); + this->Scalarize(prim_index, f); } else { LOG(FATAL) << "Cannot store value of type " << op->value.dtype() << " into buffer variable '" - << op->buffer_var->name_hint << "' with element type " << info.element_type - << " using index of type " << op->index->dtype; + << buffer_var->name_hint << "' with element type " << info.element_type + << " using index of type " << prim_index->dtype; } } diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index 74b62e7613d1..08b9db0ee539 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -100,9 +100,9 @@ class CodeGenSPIRV : public ExprFunctor, spirv::Value VisitExpr_(const CallNode* op) override; spirv::Value VisitExpr_(const RampNode* op) override; spirv::Value VisitExpr_(const BroadcastNode* op) override; - spirv::Value VisitExpr_(const LoadNode* op) override; + spirv::Value VisitExpr_(const BufferLoadNode* op) override; // stmt - void VisitStmt_(const StoreNode* op) override; + void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const ForNode* op) override; void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; diff --git 
a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index 402e3291975f..e70405445349 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -140,12 +140,21 @@ int CodeGenStackVM::GetVarID(const VarNode* v) const { } void CodeGenStackVM::VisitExpr_(const LoadNode* op) { - this->Push(op->buffer_var); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; +} + +void CodeGenStackVM::VisitExpr_(const BufferLoadNode* op) { + ICHECK_EQ(op->indices.size(), 1) << "StackVM expects flat 1-d buffers. " + << "Has StorageFlatten (TE-based schedules) or " + << "FlattenBuffer (TIR-based schedules) been run?"; + auto index = op->indices[0]; + + this->Push(op->buffer->data); StackVM::OpCode code = StackVM::GetLoad(op->dtype); - if (const IntImmNode* index = op->index.as()) { - this->PushOp(code, index->value); + if (const IntImmNode* int_index = index.as()) { + this->PushOp(code, int_index->value); } else { - this->Push(op->index); + this->Push(index); this->PushOp(StackVM::PUSH_I64, op->dtype.element_of().bytes()); this->PushOp(StackVM::MUL_I64); this->PushOp(StackVM::ADDR_ADD); @@ -154,13 +163,22 @@ void CodeGenStackVM::VisitExpr_(const LoadNode* op) { } void CodeGenStackVM::VisitStmt_(const StoreNode* op) { - this->Push(op->buffer_var); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; +} + +void CodeGenStackVM::VisitStmt_(const BufferStoreNode* op) { + ICHECK_EQ(op->indices.size(), 1) << "StackVM expects flat 1-d buffers. " + << "Has StorageFlatten (TE-based schedules) or " + << "FlattenBuffer (TIR-based schedules) been run?"; + auto index = op->indices[0]; + + this->Push(op->buffer->data); StackVM::OpCode code = StackVM::GetStore(op->value.dtype()); - if (const IntImmNode* index = op->index.as()) { + if (const IntImmNode* int_index = index.as()) { this->Push(op->value); - this->PushOp(code, index->value); + this->PushOp(code, int_index->value); } else { - this->Push(op->index); + this->Push(index); this->PushOp(StackVM::PUSH_I64, op->value.dtype().element_of().bytes()); this->PushOp(StackVM::MUL_I64); this->PushOp(StackVM::ADDR_ADD); @@ -175,11 +193,13 @@ void CodeGenStackVM::VisitStmt_(const AllocateNode* op) { void CodeGenStackVM::VisitExpr_(const CallNode* op) { if (op->op.same_as(builtin::address_of())) { - const LoadNode* l = op->args[0].as(); - ICHECK(op->args.size() == 1 && l); - this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get())); - this->Push(l->index); - this->PushOp(StackVM::PUSH_I64, l->dtype.element_of().bytes()); + const BufferLoadNode* load = op->args[0].as(); + ICHECK(op->args.size() == 1 && load); + ICHECK_EQ(load->indices.size(), 1) << "CodeGenStackVM only supports flat memory allocations."; + + this->PushOp(StackVM::LOAD_HEAP, GetVarID(load->buffer->data.get())); + this->Push(load->indices[0]); + this->PushOp(StackVM::PUSH_I64, load->dtype.element_of().bytes()); this->PushOp(StackVM::MUL_I64); this->PushOp(StackVM::ADDR_ADD); } else if (op->op.same_as(builtin::reinterpret())) { diff --git a/src/target/stackvm/codegen_stackvm.h b/src/target/stackvm/codegen_stackvm.h index 480ffc7eb870..ae6f316b475d 100644 --- a/src/target/stackvm/codegen_stackvm.h +++ b/src/target/stackvm/codegen_stackvm.h @@ -108,6 +108,7 @@ class CodeGenStackVM : public ExprFunctor, // expression void VisitExpr_(const VarNode* op) final; void VisitExpr_(const LoadNode* op) final; + void VisitExpr_(const BufferLoadNode* op) final; void VisitExpr_(const
LetNode* op) final; void VisitExpr_(const CallNode* op) final; void VisitExpr_(const AddNode* op) final; @@ -136,6 +137,7 @@ class CodeGenStackVM : public ExprFunctor, // statement void VisitStmt_(const LetStmtNode* op) final; void VisitStmt_(const StoreNode* op) final; + void VisitStmt_(const BufferStoreNode* op) final; void VisitStmt_(const ForNode* op) final; void VisitStmt_(const IfThenElseNode* op) final; void VisitStmt_(const AllocateNode* op) final; diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index 2ed5fd4029a2..e419377e7664 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -134,29 +134,25 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, // If we load from and then store into the same res_handles in the thread_allreduce intrinsic, // something goes wrong, so we use an extra variable here for normal reduction. - std::vector normal_res_handles; + std::vector normal_res_buffers; std::vector normal_init, normal_update; if (!normal_red.empty()) { - normal_res_handles.reserve(size); + normal_res_buffers.reserve(size); normal_init.reserve(size); normal_update.resize(size); const CommReducerNode* combiner = reduces[0]->combiner.as(); ICHECK(combiner); Array lhs; for (size_t i = 0; i < size; ++i) { - DataType t = reduces[i]->dtype; - normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i), - PointerType(PrimType(t), "local")); - lhs.push_back(Load(t, normal_res_handles[i], 0, const_true(t.lanes()))); + normal_res_buffers.push_back( + decl_buffer({1}, reduces[i]->dtype, "normal_reduce_temp" + std::to_string(i), "local")); + lhs.push_back(BufferLoad(normal_res_buffers[i], {0})); } Array init_value = combiner->identity_element; Array update_value = (*combiner)(lhs, reduces[0]->source); for (size_t i = 0; i < size; ++i) { - DataType t = reduces[i]->dtype; - normal_init.emplace_back( - Store(normal_res_handles[i], init_value[i], 0, const_true(t.lanes()))); - normal_update.emplace_back( - Store(normal_res_handles[i], update_value[i], 0, const_true(t.lanes()))); + normal_init.emplace_back(BufferStore(normal_res_buffers[i], init_value[i], {0})); + normal_update.emplace_back(BufferStore(normal_res_buffers[i], update_value[i], {0})); } } @@ -164,8 +160,7 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, freduce_args.push_back(make_const(DataType::UInt(32), static_cast(size))); for (size_t i = 0; i < size; ++i) { if (!normal_red.empty()) { - DataType t = reduces[i]->dtype; - freduce_args.push_back(Load(t, normal_res_handles[i], 0, const_true(t.lanes()))); + freduce_args.push_back(BufferLoad(normal_res_buffers[i], {0})); } else { freduce_args.push_back(reduces[0]->source[i]); } @@ -174,12 +169,15 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, // No constraints on the thread reduction step. It may have redundant // computation for rare cases. TODO(tvm-team): revisit this.
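// For example (illustrative): a single cross-thread sum declares two
// one-element "local" buffers below: normal_reduce_temp0, which
// accumulates the serial portion of the reduction on each thread, and
// reduce_temp0, which receives the tvm_thread_allreduce result that is
// finally copied to the stage's output.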
freduce_args.push_back(const_true(1)); - std::vector res_handles(size); + std::vector res_buffers(size); for (size_t idx = 0; idx < size; ++idx) { - DataType dtype = reduces[idx]->dtype; - res_handles[idx] = - Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype), "local")); - freduce_args.push_back(res_handles[idx]); + res_buffers[idx] = + decl_buffer({1}, reduces[idx]->dtype, "reduce_temp" + std::to_string(idx), "local"); + // Make a BufferLoad object so that we can pass the entire Buffer + // object through to LowerThreadAllreduce. The index here is + // unused. + PrimExpr dummy_load = BufferLoad(res_buffers[idx], {0}); + freduce_args.push_back(dummy_load); } for (IterVar iv : stage->leaf_iter_vars) { @@ -216,18 +214,18 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, std::vector assigns(size); for (size_t idx = 0; idx < size; ++idx) { - DataType t = reduces[idx]->dtype; - assigns[idx] = ProducerStore(stage->op.output(idx), - Load(t, res_handles[idx], 0, const_true(t.lanes())), args); + assigns[idx] = ProducerStore(stage->op.output(idx), BufferLoad(res_buffers[idx], {0}), args); } Stmt assign_body = SeqStmt::Flatten(assigns); assign_body = MergeNest(MakeIfNest(output_preds), assign_body); Stmt body = SeqStmt::Flatten(reduce_body, assign_body); for (size_t idx = size; idx != 0; --idx) { - body = Allocate(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); + const auto& res_buffer = res_buffers[idx - 1]; + body = Allocate(res_buffer->data, res_buffer->dtype, res_buffer->shape, const_true(), body); if (!normal_red.empty()) { - body = - Allocate(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); + const auto& normal_res_buffer = normal_res_buffers[idx - 1]; + body = Allocate(normal_res_buffer->data, normal_res_buffer->dtype, normal_res_buffer->shape, + const_true(), body); } } body = Substitute(body, value_map); diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index d45f29ebc5b6..b1056ac2447d 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -79,6 +79,22 @@ void PassUpThreadBinding(const Stage& stage, std::unordered_map* } else if (const RebaseNode* s = rel.as()) { state[s->parent] = state[s->rebased]; } else if (rel.as()) { + } else if (const TransformNode* s = rel.as()) { + // Currently, this marks all original iter vars as deriving from + // a thread bind if any of the transformed variables are bound, + // even if the inverse expression for that iter var doesn't + // depend on the bound variable. + + // TODO(Lunderberg): For each of original variable, check + // whether any variable in the inverse expression for it has a + // thread binding. 
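// For example (illustrative): if the transform maps original axes (i, j)
// to transformed axes (jo, ji, i2) and only jo is bound to threadIdx.x,
// the loop below conservatively marks both i and j as thread-bound, even
// if the inverse expression for i never references jo.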
+ bool is_thread_binding = false; + for (const auto& iter_var : s->transformed_variables) { + is_thread_binding = is_thread_binding || state[iter_var]; + } + for (const auto& iter_var : s->original_variables) { + state[iter_var] = is_thread_binding; + } } else { LOG(FATAL) << "unknown relation type"; } @@ -157,6 +173,29 @@ void PassDownDomain(const Stage& stage, std::unordered_map* p_st Update(p_state, r->rebased, Range::FromMinExtent(0, state.at(r->parent)->extent), actx); } else if (const SingletonNode* s = rel.as()) { Update(p_state, s->iter, Range::FromMinExtent(0, 1), actx); + } else if (const TransformNode* s = rel.as()) { + bool missing_originals = false; + for (const auto& iter_var : s->original_variables) { + if (!state.count(iter_var)) { + ICHECK(allow_missing); + missing_originals = true; + } + } + if (missing_originals) { + continue; + } + + Array original_ranges; + for (const auto& iter_var : s->original_variables) { + original_ranges.push_back(state[iter_var]); + } + Array updated_ranges = s->forward_transformation->MapRanges(original_ranges); + + ICHECK_EQ(updated_ranges.size(), s->transformed_variables.size()); + for (size_t i = 0; i < updated_ranges.size(); i++) { + Update(p_state, s->transformed_variables[i], updated_ranges[i], actx); + } + } else { LOG(FATAL) << "unknown relation type"; } @@ -225,6 +264,29 @@ void PassUpIndex(const Stage& stage, const Map& dom_map, state[s->parent] = value; } } else if (rel.as()) { + } else if (const TransformNode* s = rel.as()) { + bool missing_transformed = false; + for (const auto& iter_var : s->transformed_variables) { + if (!state.count(iter_var)) { + ICHECK(allow_missing); + missing_transformed = true; + } + } + if (missing_transformed) { + continue; + } + + Array transformed_indices; + for (const auto& iter_var : s->transformed_variables) { + transformed_indices.push_back(state[iter_var]); + } + Array original_indices = s->inverse_transformation->MapIndices(transformed_indices); + + ICHECK_EQ(original_indices.size(), s->original_variables.size()); + for (size_t i = 0; i < original_indices.size(); i++) { + state[s->original_variables[i]] = original_indices[i]; + } + } else { LOG(FATAL) << "unknown relation type"; } @@ -270,6 +332,28 @@ void PassDownIndex(const Stage& stage, const Map& dom_map, state[s->rebased] = value; } else if (const SingletonNode* s = rel.as()) { state[s->iter] = make_zero(s->iter->var.dtype()); + } else if (const TransformNode* s = rel.as()) { + bool missing_originals = false; + for (const auto& iter_var : s->original_variables) { + if (!state.count(iter_var)) { + ICHECK(allow_missing); + missing_originals = true; + } + } + if (missing_originals) { + continue; + } + + Array original_indices; + for (const auto& iter_var : s->original_variables) { + original_indices.push_back(state[iter_var]); + } + Array transformed_indices = s->forward_transformation->MapIndices(original_indices); + + ICHECK_EQ(transformed_indices.size(), s->transformed_variables.size()); + for (size_t i = 0; i < transformed_indices.size(); i++) { + state[s->transformed_variables[i]] = transformed_indices[i]; + } } else { LOG(FATAL) << "unknown relation type"; } @@ -351,6 +435,26 @@ void PassUpDomain(const RebaseNode* s, const std::unordered_map& *parent = arith::EvalSet(s->rebased->var + parent_min, {{s->rebased, rebased}}); } +Array PassUpDomain(const TransformNode* s, + const std::unordered_map& dom_map, + const Map& transformed_domains) { + Array output; + + Array transformed_indices; + for (const auto& iter_var : s->transformed_variables) 
{ + transformed_indices.push_back(iter_var->var); + } + + Array transformed_exprs = s->inverse_transformation->MapIndices(transformed_indices); + + ICHECK_EQ(transformed_exprs.size(), s->original_variables.size()); + for (size_t i = 0; i < transformed_exprs.size(); i++) { + output.push_back(arith::EvalSet(transformed_exprs[i], transformed_domains)); + } + + return output; +} + void PassUpDomain(const Stage& stage, const std::unordered_map& dom_map, std::unordered_map* p_state) { auto& state = *p_state; @@ -370,6 +474,16 @@ void PassUpDomain(const Stage& stage, const std::unordered_map& PassUpDomain(r, dom_map, state.at(r->rebased), &parent); state[r->parent] = parent; } else if (rel.as()) { + } else if (const TransformNode* r = rel.as()) { + Map transformed_domains; + for (const auto& var : r->transformed_variables) { + transformed_domains.Set(var, state.at(var)); + } + auto original_ranges = PassUpDomain(r, dom_map, transformed_domains); + ICHECK_EQ(original_ranges.size(), r->original_variables.size()); + for (size_t i = 0; i < original_ranges.size(); i++) { + state[r->original_variables[i]] = original_ranges[i]; + } } else { LOG(FATAL) << "unknown relation type"; } @@ -509,6 +623,22 @@ void PassUpBoundCheck(const Stage& s, const Map& dom_map, state[s->parent] = state.at(s->rebased); } else if (rel.as()) { // nop + } else if (const TransformNode* s = rel.as()) { + // Currently, this marks all original iter vars as requiring + // bounds checks if any of the transformed variables require + // bounds checks, even if the inverse expression for that iter + // var doesn't depend on the bound variable. + + // TODO(Lunderberg): For each of original variable, check + // whether any variable in the inverse expression for it + // requires bounds checking. + bool needs_bounds_check = false; + for (const auto& iter_var : s->transformed_variables) { + needs_bounds_check = needs_bounds_check || state[iter_var]; + } + for (const auto& iter_var : s->original_variables) { + state[iter_var] = needs_bounds_check; + } } else { LOG(FATAL) << "unknown relation type"; } diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc index 2f74d2905454..0fcd6133c4a2 100644 --- a/src/te/schedule/schedule_lang.cc +++ b/src/te/schedule/schedule_lang.cc @@ -25,8 +25,10 @@ #include #include +#include #include #include +#include #include "graph.h" @@ -429,6 +431,80 @@ Stage& Stage::rolling_buffer() { self->rolling_buffer = true; return *this; } +Stage& Stage::transform_layout(const Array& initial_indices, + const Array& final_indices, + Array* out_iter_vars) { + StageNode* self = operator->(); + IndexMap map(initial_indices, final_indices); + self->layout_transforms.push_back(map); + + auto* compute = self->op.as(); + + // Can only rewrite the indices of compute op nodes. + if (!compute) { + return *this; + } + + CHECK_EQ(initial_indices.size(), compute->axis.size()) + << "Expected number of initial indices in transformation to match the dimension of " + << self->op->name; + + // Locate the IterVar objects for the data axes. 
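// For example (illustrative): for a compute op with axis = [i, j] on an
// otherwise untouched stage, FindLeafVar returns positions 0 and 1, so
// leaf_iter_range below is {0, 2} and the contiguity CHECK passes; a
// prior split or reorder of i or j would break that contiguity and
// trigger the error instead.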
+ auto leaf_iter_range = [&]() -> std::pair { + std::vector leaf_var_indices; + for (const auto& axis : compute->axis) { + leaf_var_indices.push_back( + FindLeafVar(self->all_iter_vars.CopyOnWrite(), self->leaf_iter_vars.CopyOnWrite(), axis)); + } + auto minmax_element = std::minmax_element(leaf_var_indices.begin(), leaf_var_indices.end()); + return {*minmax_element.first, *minmax_element.second + 1}; + }(); + CHECK_EQ(leaf_iter_range.first + compute->axis.size(), leaf_iter_range.second) + << "Cannot transform indices if they have already been reordered"; + + // Determine the updated ranges of iteration. + Array initial_ranges; + for (const auto& iter_var : compute->axis) { + initial_ranges.push_back(iter_var->dom); + } + Array final_ranges = map->MapRanges(initial_ranges); + + // Make IterVar objects to represent the new iterations. + auto inverse = map.Inverse(initial_ranges); + Array final_indices_iter; + ICHECK_EQ(inverse->initial_indices.size(), final_ranges.size()); + for (size_t i = 0; i < inverse->initial_indices.size(); i++) { + final_indices_iter.push_back(IterVar(final_ranges[i], inverse->initial_indices[i], kDataPar)); + } + + // Append the new IterVar objects to all_iter_vars + for (const auto& iter_var : final_indices_iter) { + self->all_iter_vars.push_back(iter_var); + } + + // Replace the existing IterVar objects in leaf_iter_vars with the + // new IterVar objects. + self->leaf_iter_vars.erase(self->leaf_iter_vars.begin() + leaf_iter_range.first, + self->leaf_iter_vars.begin() + leaf_iter_range.second); + self->leaf_iter_vars.insert(self->leaf_iter_vars.begin() + leaf_iter_range.first, + final_indices_iter.begin(), final_indices_iter.end()); + + // Define a relationship for each new axis + self->relations.push_back(Transform(compute->axis, final_indices_iter, map, inverse)); + + // Return the iteration variables as an output. 
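+  // Worked example (editorial): for axes i in [0, 16) and j in [0, 64)
+  // under the map (i, j) -> (j/4, i, j%4), MapRanges yields the final
+  // ranges [0, 16), [0, 16), and [0, 4); the inverse map introduces
+  // fresh variables (axis0, axis1, axis2) with i = axis1 and
+  // j = 4*axis0 + axis2, and those variables become the replacement
+  // leaf IterVars constructed above.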
+ if (out_iter_vars) { + *out_iter_vars = final_indices_iter; + } + + return *this; +} + +Stage& Stage::set_axis_separators(const Array& axis_separators) { + StageNode* self = operator->(); + self->axis_separators = axis_separators; + return *this; +} Stage CopyStage(const Stage& s) { ObjectPtr n = make_object(*s.operator->()); @@ -711,6 +787,16 @@ Singleton::Singleton(IterVar iter) { data_ = std::move(n); } +Transform::Transform(Array original_variables, Array transformed_variables, + IndexMap forward_transformation, IndexMap inverse_transformation) { + auto n = make_object(); + n->original_variables = original_variables; + n->transformed_variables = transformed_variables; + n->forward_transformation = forward_transformation; + n->inverse_transformation = inverse_transformation; + data_ = std::move(n); +} + SpecializedCondition::SpecializedCondition(Array conditions) { ObjectPtr n = make_object(); n->clauses = std::move(conditions); @@ -895,6 +981,16 @@ TVM_REGISTER_GLOBAL("te.StageDoubleBuffer").set_body_method(&Stage::double_buffe TVM_REGISTER_GLOBAL("te.StageRollingBuffer").set_body_method(&Stage::rolling_buffer); +TVM_REGISTER_GLOBAL("te.StageTransformLayout") + .set_body_typed([](Stage stage, const Array& initial_indices, + const Array& final_indices) { + Array new_iter_vars; + stage.transform_layout(initial_indices, final_indices, &new_iter_vars); + return new_iter_vars; + }); + +TVM_REGISTER_GLOBAL("te.StageSetAxisSeparators").set_body_method(&Stage::set_axis_separators); + TVM_REGISTER_GLOBAL("te.ScheduleNormalize").set_body_method(&Schedule::normalize); TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup").set_body_method(&Schedule::create_group); diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 99e02ccaf943..47ef4af1c4c4 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -40,12 +40,36 @@ namespace te { using namespace tir; +// Annotate the statement with the layout transforms and axis +// separators of the stage. These annotations are removed during +// SchedulePostProcToPrimFunc. Afterwards, layout transforms are +// specified in the PrimFunc attrs, and the axis_separators are +// specified in the BufferNode. 
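+// Schematically (an editorial sketch in pseudo-TIR, not output copied
+// from the compiler), a single-output stage carrying both annotations
+// is wrapped as:
+//
+// \code
+//   attr [[Tensor(C), separators]] "axis_separators" = 1
+//     attr [[Tensor(C), transforms]] "layout_transforms" = 1
+//       produce C { ... }
+// \endcode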
+Stmt WrapLayoutTransformationAttrs(const Stage& stage, Stmt body) {
+  if (stage->layout_transforms.size()) {
+    for (int i = 0; i < stage->op->num_outputs(); i++) {
+      body = AttrStmt(Array<ObjectRef>{stage->op.output(i), stage->layout_transforms},
+                      tir::attr::layout_transforms, 1, body);
+    }
+  }
+
+  if (stage->axis_separators.size()) {
+    for (int i = 0; i < stage->op->num_outputs(); i++) {
+      body = AttrStmt(Array<ObjectRef>{stage->op.output(i), stage->axis_separators},
+                      tir::attr::axis_separators, 1, body);
+    }
+  }
+
+  return body;
+}
+
 Stmt MakePipeline(const Stage& s, const std::unordered_map<IterVar, Range>& dom_map, Stmt consumer,
                   bool debug_keep_trivial_loop) {
   Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop);
   if (s->double_buffer) {
     producer = AttrStmt(s->op, tir::attr::double_buffer_scope, 1, producer);
   }
+  producer = WrapLayoutTransformationAttrs(s, producer);
 
   Stmt pipeline = producer;
 
   if (consumer.defined() && !is_no_op(consumer)) {
@@ -209,6 +233,23 @@ class SchedulePostProc : public StmtExprMutator {
           return this->VisitStmt(op->body);
         }
       }
+    } else if (op->attr_key == tir::attr::layout_transforms ||
+               op->attr_key == tir::attr::axis_separators) {
+      auto arr = Downcast<Array<ObjectRef>>(op->node);
+      ICHECK_EQ(arr.size(), 2);
+
+      Stmt body = op->body;
+
+      Tensor tensor = Downcast<Tensor>(arr[0]);
+      auto it = replace_op_.find(tensor->op.get());
+      if (it != replace_op_.end()) {
+        if (it->second.defined()) {
+          return AttrStmt(Array<ObjectRef>{it->second.output(tensor->value_index), arr[1]},
+                          op->attr_key, op->value, this->VisitStmt(op->body));
+        } else {
+          return this->VisitStmt(op->body);
+        }
+      }
     }
     return StmtExprMutator::VisitStmt_(op);
   }
@@ -349,12 +390,16 @@ Stmt ScheduleOps(Schedule sch, Map<IterVar, Range> dom_map_, bool debug_keep_tri
     Stage s = sch->stages[i - 1];
     ICHECK_NE(s->attach_type, kInline) << "call schedule.normalize before scheduleops";
     ICHECK(s->op.defined());
-    // no need to specify place holder op.
-    if (s->op.as<PlaceholderOpNode>()) continue;
     // Remove grouping sugar, get the real attach spec.
     Stage attach_spec = s.GetAttachSpec();
-    if (scan_init.count(s->op)) {
+    if (s->op.as<PlaceholderOpNode>()) {
+      // Placeholders don't need any realize/provide statements, but
+      // may be annotated with transform_layout or set_axis_separators
+      // to indicate the physical layout of an input, and must still
+      // have the attribute given.
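+      // For instance (editorial illustration; A is an assumed input
+      // placeholder):
+      //
+      // \code
+      //   s[A].set_axis_separators({IntImm(DataType::Int(32), 2)});
+      // \endcode
+      //
+      // still emits the "axis_separators" attribute around the body
+      // below, even though no realize/provide is generated for A.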
+ body = WrapLayoutTransformationAttrs(s, std::move(body)); + } else if (scan_init.count(s->op)) { ICHECK(body.defined()); InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop); body = mu(std::move(body)); @@ -381,6 +426,7 @@ Stmt ScheduleOps(Schedule sch, Map dom_map_, bool debug_keep_tri << body; } } + SchedulePostProc post_proc; post_proc.Init(sch); return post_proc(std::move(body)); diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index 7e8b12b6d61e..0cf6e54391da 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -42,6 +42,7 @@ #include #include +#include #include #include @@ -55,6 +56,7 @@ Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "") { name += ".v" + std::to_string(tensor->value_index); } Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name, storage_scope); + return buffer; } @@ -86,6 +88,17 @@ class TensorToBufferMapper : public StmtExprMutator { Tensor tensor = Downcast(op->node); Buffer buffer = GetOrAllocBuffer(tensor); return AttrStmt(buffer, op->attr_key, op->value, op->body); + } else if (op->attr_key == tir::attr::layout_transforms || + op->attr_key == tir::attr::axis_separators) { + auto arr = Downcast>(op->node); + ICHECK_EQ(arr.size(), 2); + + Stmt body = op->body; + + Tensor tensor = Downcast(arr[0]); + Buffer buffer = GetBuffer(tensor); + + return AttrStmt(Array{buffer, arr[1]}, op->attr_key, 1, body); } else { return ret; } @@ -108,7 +121,7 @@ class TensorToBufferMapper : public StmtExprMutator { auto ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); - return BufferStore(buffer, op->value, op->indices); + return BufferStore(buffer, op->value, GetIndices(op->indices, buffer->shape)); } PrimExpr VisitExpr_(const ProducerLoadNode* op) final { @@ -116,7 +129,7 @@ class TensorToBufferMapper : public StmtExprMutator { op = ret.as(); Tensor tensor = Downcast(op->producer); Buffer buffer = GetBuffer(tensor); - return tir::BufferLoad(buffer, op->indices); + return tir::BufferLoad(buffer, GetIndices(op->indices, buffer->shape)); } private: @@ -134,46 +147,279 @@ class TensorToBufferMapper : public StmtExprMutator { return buffer; } - // maps tensor to buffer. + Array GetIndices(const Array& tensor_indices, + const Array& buffer_shape) { + if (tensor_indices.size() == buffer_shape.size()) { + return tensor_indices; + } else if (tensor_indices.size() == 1) { + // Workaround to support previous behavior of tensor indexing by + // a single index, treating the tensor as if were already + // flattened by a row-major traversal. + PrimExpr unravel = tensor_indices[0]; + Array rev_indices; + for (size_t i = buffer_shape.size(); i > 0; i--) { + PrimExpr dim = buffer_shape[i - 1]; + rev_indices.push_back(indexmod(unravel, dim)); + unravel = indexdiv(unravel, dim); + } + return Array(rev_indices.rbegin(), rev_indices.rend()); + } else { + LOG(FATAL) << "Cannot produce indices for " << buffer_shape.size() + << "-dimensional TIR buffer using " << tensor_indices.size() + << "-dimensional tensor indices."; + return {}; + } + } + + // Maps tensor to buffer. std::unordered_map buffer_map_; }; +/*! Collect the physical layout map of all tensors in the statement. */ +class LayoutTransformAttrUnwrapper : StmtExprMutator { + public: + static tir::PrimFunc Apply(tir::PrimFunc func) { + // Collect the physical layout annotations in the body, which may + // refer to input arguments. 
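+    // Conceptually (editorial sketch; buffer_C and the index
+    // expressions are assumptions), the collected map looks like
+    //
+    // \code
+    //   Map<Buffer, Array<IndexMap>> layout_map =
+    //       {{buffer_C, {IndexMap({i, j}, {floordiv(j, 4), i, floormod(j, 4)})}}};
+    // \endcode
+    //
+    // and is attached to the PrimFunc under the "layout_transform_map" key.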
+ auto layout_map = Collector::Collect(func->body); + + if (layout_map.size()) { + func = WithAttr(std::move(func), "layout_transform_map", layout_map); + + auto write_ptr = func.CopyOnWrite(); + write_ptr->body = LayoutTransformAttrUnwrapper()(func->body); + } + + return func; + } + + LayoutTransformAttrUnwrapper() {} + + Stmt VisitStmt_(const AttrStmtNode* op) final { + auto ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + if (op->attr_key == tir::attr::layout_transforms) { + return op->body; + } else { + return ret; + } + } + + private: + /*! Collect the physical layout information of all tensors in the statement. + * + * Must be done before constructing the buffers, since the + * attributes could either apply to the external buffers or to + * internal allocations. + */ + class Collector : StmtExprVisitor { + public: + static Map> Collect(Stmt stmt) { + Collector collector; + collector(std::move(stmt)); + return std::move(collector.layout_map_); + } + + Collector() {} + + void VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == tir::attr::layout_transforms) { + auto arr = Downcast>(op->node); + ICHECK_EQ(arr.size(), 2); + + auto buffer = Downcast(arr[0]); + auto layout_transforms = Downcast>(arr[1]); + layout_map_.Set(buffer, layout_transforms); + } + StmtExprVisitor::VisitStmt_(op); + } + + Map> layout_map_; + }; + + std::unordered_map buffer_remap_; + + Map> layout_map_; +}; + +/*! Move axis_separators from an attribute to a buffer property. */ +class AxisSeparatorsAttrUnwrapper : StmtExprMutator { + public: + static tir::PrimFunc Apply(tir::PrimFunc func) { + // Collect the physical layout annotations in the body, which may + // refer to input arguments. + auto axis_separators_map = Collector::Collect(func->body); + + if (axis_separators_map.size()) { + auto write_ptr = func.CopyOnWrite(); + auto pass = AxisSeparatorsAttrUnwrapper(axis_separators_map); + write_ptr->buffer_map = pass.UpdateExternBufferMap(func->buffer_map); + write_ptr->body = pass(func->body); + } + + return func; + } + + explicit AxisSeparatorsAttrUnwrapper(Map> axis_separators_map) + : axis_separators_map_(axis_separators_map) {} + + Map UpdateExternBufferMap(const Map& orig) { + Map output; + for (const auto& kv : orig) { + output.Set(kv.first, GetRemappedBuffer(kv.second)); + } + return output; + } + + Stmt VisitStmt_(const AttrStmtNode* op) final { + auto ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + if (op->attr_key == tir::attr::axis_separators) { + return op->body; + } else { + return ret; + } + } + + Stmt VisitStmt_(const BufferRealizeNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); + } + + private: + template + Node VisitBufferAccess(Node node) { + Buffer new_buf = GetRemappedBuffer(node->buffer); + if (!node->buffer.same_as(new_buf)) { + auto writer = node.CopyOnWrite(); + writer->buffer = new_buf; + } + return node; + } + + Buffer GetRemappedBuffer(Buffer buf) { + // If this buffer has already been remapped, then return the + // previous value. 
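+    // (Editorial note) The memoization below guarantees that repeated
+    // lookups observe a pointer-identical Buffer, so every access site
+    // shares one rewritten object:
+    //
+    // \code
+    //   Buffer b1 = GetRemappedBuffer(buf);
+    //   Buffer b2 = GetRemappedBuffer(buf);
+    //   ICHECK(b1.same_as(b2));
+    // \endcode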
+ auto key = buf.get(); + { + auto it = buffer_remap_.find(key); + if (it != buffer_remap_.end()) { + return it->second; + } + } + + // Otherwise, check if we need to add axis_separators to this + // buffer. + auto lookup = axis_separators_map_.Get(buf); + if (lookup) { + Array axis_separators = lookup.value(); + if (axis_separators.size()) { + auto write_ptr = buf.CopyOnWrite(); + write_ptr->axis_separators = axis_separators; + } + } + + // And cache the result for next time. + buffer_remap_[key] = buf; + + return buf; + } + + /*! Collect the axis separator information of all tensors in the statement. + * + * Must be done before constructing the buffers, since the + * attributes could either apply to the external buffers or to + * internal allocations. + */ + class Collector : StmtExprVisitor { + public: + static Map> Collect(Stmt stmt) { + Collector collector; + collector(std::move(stmt)); + return std::move(collector.axis_separators_map_); + } + + Collector() {} + + void VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == tir::attr::axis_separators) { + auto arr = Downcast>(op->node); + ICHECK_EQ(arr.size(), 2); + + auto buffer = Downcast(arr[0]); + auto axis_separators = Downcast>(arr[1]); + axis_separators_map_.Set(buffer, axis_separators); + } + StmtExprVisitor::VisitStmt_(op); + } + + Map> axis_separators_map_; + }; + + std::unordered_map buffer_remap_; + + Map> axis_separators_map_; +}; + PrimFunc SchedulePostProcToPrimFunc(Array arg_list, Stmt body, Optional> extern_buffer_opt) { - std::unordered_map extern_buffer; + std::unordered_map extern_tensor_map; if (extern_buffer_opt.defined()) { auto v = extern_buffer_opt.value(); - extern_buffer = std::unordered_map(v.begin(), v.end()); + extern_tensor_map = std::unordered_map(v.begin(), v.end()); } Array params; Map buffer_map; - for (auto var : arg_list) { - if (auto* n = var.as()) { + for (auto arg : arg_list) { + if (auto* n = arg.as()) { + tir::Var var = GetRef(n); params.push_back(GetRef(n)); - } else if (auto* n = var.as()) { + } else if (auto* n = arg.as()) { te::Tensor tensor = GetRef(n); - ICHECK(!extern_buffer.count(tensor)); + ICHECK(!extern_tensor_map.count(tensor)); tir::Buffer buffer = CreateBufferFor(tensor); tir::Var bptr(buffer->name, PrimType(DataType::Handle())); params.push_back(bptr); buffer_map.Set(bptr, buffer); - extern_buffer[tensor] = buffer; - } else { - tir::Buffer buffer = Downcast(var); + extern_tensor_map[tensor] = buffer; + } else if (auto* n = arg.as()) { + tir::Buffer buffer = GetRef(n); tir::Var bptr(buffer->name, PrimType(DataType::Handle())); params.push_back(bptr); buffer_map.Set(bptr, buffer); + } else { + LOG(FATAL) << "Expected argument to be Var, Tensor, or Buffer, but received " + << arg->GetTypeKey(); } } - body = TensorToBufferMapper(std::move(extern_buffer))(std::move(body)); + body = TensorToBufferMapper(std::move(extern_tensor_map))(std::move(body)); + + PrimFunc func = tir::PrimFunc(params, body, VoidType(), buffer_map); + + func = LayoutTransformAttrUnwrapper::Apply(std::move(func)); + func = AxisSeparatorsAttrUnwrapper::Apply(std::move(func)); + // We mark this PrimFunc as coming from a TE schedule - return WithAttr(tir::PrimFunc(params, body, VoidType(), buffer_map), "from_legacy_te_schedule", - Bool(true)); + func = WithAttr(func, "from_legacy_te_schedule", Bool(true)); + + return func; } TVM_REGISTER_GLOBAL("schedule.SchedulePostProcToPrimFunc") diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index 
3038eca8d338..974f6ecd644f 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -141,14 +141,13 @@ Array BlockReadWriteDetector::CollectOpaques() { void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef(op)); } void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) { - UpdateOpaque(op->buffer_var); - ExprVisitor::VisitExpr_(op); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; } void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) { std::vector relaxed_region; for (const PrimExpr& index : op->indices) { - relaxed_region.push_back(arith::EvalSet(index, dom_map_)); + relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_)); } Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region); ExprVisitor::VisitExpr_(op); @@ -194,14 +193,13 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { } void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) { - UpdateOpaque(op->buffer_var); - StmtVisitor::VisitStmt_(op); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; } void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) { std::vector relaxed_region; for (const PrimExpr& index : op->indices) { - relaxed_region.push_back(arith::EvalSet(index, dom_map_)); + relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_)); } Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region); StmtVisitor::VisitStmt_(op); diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc index e680d689735d..b71e6b27f486 100644 --- a/src/tir/analysis/buffer_access_lca_detector.cc +++ b/src/tir/analysis/buffer_access_lca_detector.cc @@ -43,6 +43,13 @@ class LCADetector : public StmtExprVisitor { detector.buffer_var_map_.emplace(buffer->data.get(), buffer.get()); } + // The root node must be explicitly present in the list of + // ancestor_scopes_. We cannot use nullptr to represent the root + // node, as that is also used to represent a scope that hasn't + // been observed before. + ScopeInfo root(nullptr, nullptr, 0); + detector.ancestor_scopes_.push_back(&root); + detector(func->body); // Prepare the return Map> buffer_lca; @@ -120,13 +127,11 @@ class LCADetector : public StmtExprVisitor { // Explict to visit buffer data in Load and Store node. void VisitExpr_(const LoadNode* op) final { - ExprVisitor::VisitExpr_(op); - VisitBufferVar(op->buffer_var.get()); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; } void VisitStmt_(const StoreNode* op) final { - StmtVisitor::VisitStmt_(op); - VisitBufferVar(op->buffer_var.get()); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; } void VisitBufferVar(const VarNode* op) { @@ -137,6 +142,7 @@ class LCADetector : public StmtExprVisitor { } void UpdateBufferLCA(const BufferNode* buffer) { + buffer_var_map_.emplace(buffer->data.get(), buffer); if (match_buffers_.find(buffer) == match_buffers_.end()) { // Ingore buffer created by block match_buffer const ScopeInfo*& lca = buffer_lca_[buffer]; @@ -169,8 +175,11 @@ class LCADetector : public StmtExprVisitor { return lhs; } - /*! \brief The ancestor scope stacks info (Block and For), initialized with Null. */ - std::vector ancestor_scopes_ = {nullptr}; + /*! \brief The ancestor scope stacks info (Block and For). 
The + * first element is initialized in LCADetector::Detect to represent + * the root scope. + */ + std::vector ancestor_scopes_ = {}; /*! \brief The map from Buffer to its LCA ForNode/BlockNode. */ std::unordered_map buffer_lca_ = {}; /*! \brief The map from Buffer data to the Buffer. */ diff --git a/src/tir/analysis/device_constraint_utils.cc b/src/tir/analysis/device_constraint_utils.cc index 26cf66c4d4c0..1309681513a9 100644 --- a/src/tir/analysis/device_constraint_utils.cc +++ b/src/tir/analysis/device_constraint_utils.cc @@ -210,6 +210,8 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { // Start with a copy of the current prim_func buffer map. Map new_buffer_map(prim_func->buffer_map.begin(), prim_func->buffer_map.end()); + Map new_preflattened_buffer_map(prim_func->preflattened_buffer_map.begin(), + prim_func->preflattened_buffer_map.end()); bool any_change = false; // For each constrained parameter... @@ -223,6 +225,23 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { any_change = true; } new_buffer_map.Set(param, new_buffer); + + // Rewrite the pre-flattened buffers to account for constraint. + // This only has an impact if the IRModule being analyzed has + // already been run through the StorageFlatten or FlattenBuffer + // passes. + if (auto opt = prim_func->preflattened_buffer_map.Get(param)) { + Buffer pf_buffer = opt.value(); + if (pf_buffer.same_as(buffer)) { + new_preflattened_buffer_map.Set(param, new_buffer); + } else { + const Buffer new_buffer = RewriteBuffer(pf_buffer, virtual_device); + if (!new_buffer.same_as(pf_buffer)) { + any_change = true; + } + new_preflattened_buffer_map.Set(param, new_buffer); + } + } } // Make sure we have accounted for all prim_func parameters. CheckNoRemainingPointerParams(prim_func, ¤t_primfunc_param_index); @@ -240,7 +259,8 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { if (any_change) { return PrimFunc(prim_func->params, std::move(new_body), prim_func->ret_type, - std::move(new_buffer_map), prim_func->attrs, prim_func->span); + std::move(new_buffer_map), std::move(new_preflattened_buffer_map), + prim_func->attrs, prim_func->span); } else { return prim_func; } @@ -425,9 +445,8 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { PointerType new_pointer_type(pointer_type_node->element_type, virtual_device->memory_scope); Var new_data(buffer->data->name_hint, new_pointer_type, buffer->data->span); var_subst_.emplace(buffer->data.get(), new_data); - Buffer new_buffer(new_data, buffer->dtype, buffer->shape, buffer->strides, buffer->elem_offset, - buffer->name, buffer->data_alignment, buffer->offset_factor, - buffer->buffer_type, buffer->span); + Buffer new_buffer = buffer; + new_buffer.CopyOnWrite()->data = new_data; buffer_subst_.emplace(buffer.get(), new_buffer); return new_buffer; } diff --git a/src/tir/analysis/var_touch.cc b/src/tir/analysis/var_touch.cc index c4acd2b74aad..f92afc4d15a1 100644 --- a/src/tir/analysis/var_touch.cc +++ b/src/tir/analysis/var_touch.cc @@ -44,13 +44,21 @@ class VarTouchVisitor : public StmtExprVisitor { void VisitExpr_(const VarNode* op) final { Handle(op); } + void VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + } + void VisitStmt_(const StoreNode* op) final { - Handle(op->buffer_var.get()); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. 
Please use BufferStoreNode instead."; + } + + void VisitStmt_(const BufferStoreNode* op) final { + Handle(op->buffer->data.get()); StmtVisitor::VisitStmt_(op); } - void VisitExpr_(const LoadNode* op) final { - Handle(op->buffer_var.get()); + void VisitExpr_(const BufferLoadNode* op) final { + Handle(op->buffer->data.get()); ExprVisitor::VisitExpr_(op); } diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index c1579c21f249..b082581a5148 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -184,7 +184,15 @@ class GPUCodeVerifier : public StmtExprVisitor { StmtVisitor::VisitStmt_(op); } - void VisitExpr_(const LoadNode* op) { + void VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + } + + void VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + } + + void VisitExpr_(const BufferLoadNode* op) { if (op->dtype.lanes() > 1) { if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { std::stringstream s; @@ -197,7 +205,7 @@ class GPUCodeVerifier : public StmtExprVisitor { ExprVisitor::VisitExpr_(op); } - void VisitStmt_(const StoreNode* op) { + void VisitStmt_(const BufferStoreNode* op) { if (op->value->dtype.lanes() > 1) { if (static_cast(op->value->dtype.lanes() * op->value->dtype.bytes()) > max_vector_bytes_) { diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index b6c41b958c31..6ee30e04704a 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -89,12 +89,20 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { } void VisitExpr_(const LoadNode* op) final { - HandleLoadStoreToVariable(op->buffer_var); - return StmtExprVisitor::VisitExpr_(op); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; } void VisitStmt_(const StoreNode* op) final { - HandleLoadStoreToVariable(op->buffer_var); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + } + + void VisitExpr_(const BufferLoadNode* op) final { + HandleLoadStoreToVariable(op->buffer->data); + return StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const BufferStoreNode* op) final { + HandleLoadStoreToVariable(op->buffer->data); return StmtExprVisitor::VisitStmt_(op); } //@} diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 24aacc3c04f7..4fe9b162078e 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -48,10 +48,10 @@ Array SimplifyArray(arith::Analyzer* ana, Array array) { } Buffer decl_buffer(Array shape, DataType dtype, String name, String storage_scope, - Span span) { + Array axis_separators, Span span) { DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype); return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape, - Array(), PrimExpr(), name, 0, 0, kDefault, span); + Array(), PrimExpr(), name, 0, 0, kDefault, axis_separators, span); } // Split the given expression w.r.t the add operator @@ -243,82 +243,187 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) { return no_opt_sum; } +Array Buffer::OffsetOf(Array input_indices) const { + return (*this)->ElemOffset(std::move(input_indices)); +} + // The buffer offset in convention of number of elements of // original data ignoring number of lanes. 
// We also perform optimization to simplify the indexing expression. -PrimExpr BufferNode::ElemOffset(Array index) const { - PrimExpr base = this->elem_offset; +Array BufferNode::ElemOffset(Array input_indices) const { + ICHECK_EQ(shape.size(), input_indices.size()) + << "Buffer " << this->name << " is " << shape.size() + << "-dimensional, cannot be indexed with the " << input_indices.size() + << "-dimensional indices provided."; + + if (strides.size()) { + ICHECK_EQ(this->strides.size(), input_indices.size()) + << "If strides are defined, " + << "the index's dimensionality must match the dimensionality of the index given."; + } + + // TODO(Lunderberg): Better handling for cases where there is more + // than one output index. Currently, this only allows elem_offset + // to be non-zero for flat memory allocations. + Array elem_offsets = {}; + if (elem_offset.defined() && !is_zero(elem_offset)) { + elem_offsets = {elem_offset}; + } + + if (elem_offsets.size()) { + ICHECK_EQ(elem_offsets.size(), axis_separators.size() + 1) + << "If element offsets are defined, " + << "there must be one element offset for each output index."; + } + + Array output_indices(axis_separators.size() + 1, 0); + + size_t current_output_axis = 0; + arith::Analyzer ana; - if (this->strides.size() == 0) { - // Scalar case - if (this->shape.size() == 0 && index.size() == 1) { - auto is_int = index[0].as(); - ICHECK(is_int && is_int->value == 0); - base = base + index[0]; - } else { - ICHECK_EQ(this->shape.size(), index.size()); - if (index.size() > 0) { - PrimExpr offset = index[0]; - for (size_t i = 1; i < index.size(); ++i) { - offset = MergeMulMod(&ana, offset * this->shape[i] + index[i]); - } - base = base + offset; - } + + for (size_t i = 0; i < input_indices.size(); i++) { + if ((current_output_axis < axis_separators.size()) && + (i == size_t(axis_separators[current_output_axis]->value))) { + current_output_axis++; } - } else { - ICHECK_EQ(this->strides.size(), index.size()); - if (is_zero(base)) { - base = MergeMulMod(&ana, index[0] * this->strides[0]); + + PrimExpr output_index = output_indices[current_output_axis]; + if (strides.size()) { + output_index = output_index + input_indices[i] * strides[i]; } else { - base = MergeMulMod(&ana, base + index[0] * this->strides[0]); + output_index = output_index * this->shape[i] + input_indices[i]; + } + + if (i > 0) { + output_index = MergeMulMod(&ana, output_index); } - for (size_t i = 1; i < index.size(); ++i) { - base = MergeMulMod(&ana, base + index[i] * this->strides[i]); + + output_indices.Set(current_output_axis, output_index); + } + + if (elem_offsets.size()) { + for (size_t i = 0; i < output_indices.size(); i++) { + output_indices.Set(i, output_indices[i] + elem_offsets[i]); } } - return base; + + return SimplifyArray(&ana, output_indices); } -inline PrimExpr BufferOffset(const BufferNode* n, Array index, DataType dtype) { - PrimExpr offset = n->ElemOffset(index); +inline Array BufferOffset(const BufferNode* n, Array index, DataType dtype) { + Array offsets = n->ElemOffset(index); + // If the Buffer has element type with more than one lane, scale to + // get the offset in number of scalars. if (n->dtype.lanes() != 1) { - offset = offset * make_const(offset.dtype(), dtype.lanes()); + PrimExpr last_offset = offsets[offsets.size() - 1]; + offsets.Set(offsets.size() - 1, last_offset * make_const(last_offset.dtype(), dtype.lanes())); } + + // If the requested type has more than one lane, make a RampNode at + // that offset. 
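+  // Example (editorial): a float32x4 access at flattened offset `o`
+  // rewrites the final offset to Ramp(o, 1, 4), selecting the four
+  // consecutive elements o, o+1, o+2, and o+3.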
   if (dtype.lanes() != 1) {
-    return tir::Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes());
+    PrimExpr last_offset = offsets[offsets.size() - 1];
+    PrimExpr stride = make_const(last_offset.dtype(), 1);
+    offsets.Set(offsets.size() - 1, tir::Ramp(last_offset, stride, dtype.lanes()));
+  }
+
+  return offsets;
+}
+
+Buffer Buffer::GetFlattenedBuffer() const {
+  auto self = operator->();
+
+  // These checks ensure that all output axes contain at least one
+  // input axis.
+  for (size_t i = 0; (i + 1) < self->axis_separators.size(); i++) {
+    auto sep = self->axis_separators[i]->value;
+    auto next_sep = self->axis_separators[i + 1]->value;
+    ICHECK_LT(sep, next_sep) << "Axis separators must be in strictly increasing order.";
+  }
+  if (self->axis_separators.size()) {
+    auto first_sep = self->axis_separators[0]->value;
+    ICHECK_GT(first_sep, 0) << "First axis separator must be strictly greater than 0, "
+                            << "so that the first output axis contains at least one input axis.";
+    auto last_sep = self->axis_separators[self->axis_separators.size() - 1]->value;
+    ICHECK_LT(last_sep, self->shape.size())
+        << "Last output axis must contain at least one input axis.";
+  }
+
+  Array<PrimExpr> output_shape;
+  if (self->strides.size()) {
+    // If strides are defined, then the extent of each flattened
+    // buffer is the stride*size for the first input axis used for
+    // each output axis.
+    ICHECK_EQ(self->shape.size(), self->strides.size());
+    output_shape.push_back(self->strides[0] * self->shape[0]);
+    for (const auto& sep : self->axis_separators) {
+      output_shape.push_back(self->strides[sep->value] * self->shape[sep->value]);
+    }
+  } else {
-    return offset;
+    // Otherwise, the extent of each flattened buffer is the product
+    // of the extents of each input axis used to generate that output
+    // axis.  This also "flattens" rank-0 tensors to a rank-1 buffer
+    // of shape [1].
+    output_shape = Array<PrimExpr>(self->axis_separators.size() + 1, 1);
+    size_t current_output_index = 0;
+    for (size_t i = 0; i < self->shape.size(); i++) {
+      if ((current_output_index < self->axis_separators.size()) &&
+          (i == size_t(self->axis_separators[current_output_index]->value))) {
+        current_output_index += 1;
+      }
+      output_shape.Set(current_output_index, output_shape[current_output_index] * self->shape[i]);
+    }
   }
+
+  // The axis_separators for the output buffer.
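+  // Worked example (editorial): shape [2, 3, 4, 5] with
+  // axis_separators [2] flattens to shape [2*3, 4*5] = [6, 20], and
+  // the flattened buffer's own axis_separators become [1], one per
+  // retained split point.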
+ Array output_axis_separators; + for (size_t i = 0; i < self->axis_separators.size(); i++) { + auto dtype = self->axis_separators[i]->dtype; + output_axis_separators.push_back(IntImm(dtype, i + 1)); + } + + Buffer output = *this; + auto writer = output.CopyOnWrite(); + writer->shape = output_shape; + writer->axis_separators = output_axis_separators; + writer->strides = {}; + + return output; } -PrimExpr Buffer::vload(Array begin, DataType dtype) const { +PrimExpr Buffer::vload(Array begin, DataType value_dtype) const { // specially handle bool, stored as DataType::Int(8) const BufferNode* n = operator->(); ICHECK(n != nullptr); - ICHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) - << "Cannot load " << dtype << " from buffer of " << n->dtype; - if (dtype == DataType::Bool()) { - return tir::Cast(DataType::Bool(), - tir::Load(DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)), - const_true())); - } else { - return tir::Load(dtype, n->data, BufferOffset(n, begin, dtype), const_true(dtype.lanes())); + ICHECK(value_dtype.element_of() == n->dtype.element_of() && + value_dtype.lanes() % n->dtype.lanes() == 0) + << "Cannot load " << value_dtype << " from buffer of " << n->dtype; + + Array indices = begin; + int factor = value_dtype.lanes() / n->dtype.lanes(); + if (factor > 1) { + indices.Set(indices.size() - 1, Ramp(indices[indices.size() - 1], 1, factor)); } + return BufferLoad(*this, indices); } Stmt Buffer::vstore(Array begin, PrimExpr value) const { // specially handle bool, stored as DataType::Int(8) const BufferNode* n = operator->(); ICHECK(n != nullptr); - DataType dtype = value.dtype(); - ICHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) - << "Cannot store " << dtype << " to buffer of " << n->dtype; - if (value.dtype() == DataType::Bool()) { - return tir::Store(n->data, tir::Cast(DataType::Int(8), value), - BufferOffset(n, begin, DataType::Int(8)), const_true()); - } else { - return tir::Store(n->data, value, BufferOffset(n, begin, dtype), const_true(dtype.lanes())); + DataType value_dtype = value.dtype(); + ICHECK(value_dtype.element_of() == n->dtype.element_of() && + value_dtype.lanes() % n->dtype.lanes() == 0) + << "Cannot store " << value_dtype << " to buffer of " << n->dtype; + + Array indices = begin; + int factor = value_dtype.lanes() / n->dtype.lanes(); + if (factor > 1) { + indices.Set(indices.size() - 1, Ramp(indices[indices.size() - 1], 1, factor)); } + return BufferStore(*this, value, indices); } String Buffer::scope() const { @@ -353,7 +458,10 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const ICHECK(n != nullptr); arith::Analyzer ana; begins = SimplifyArray(&ana, begins); - PrimExpr elem_offset = ana.Simplify(n->ElemOffset(begins)); + Array elem_offset = n->ElemOffset(begins); + elem_offset.MutateByApply([&](const PrimExpr& expr) { return ana.Simplify(expr); }); + ICHECK_EQ(elem_offset.size(), 1) << "MakeSlice currently supports only flat 1-d memory."; + Array strides = n->strides; if (strides.size() == 0) { bool can_relax = true; @@ -372,7 +480,7 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const return MakeStrideView().MakeSlice(begins, extents); } } - return Buffer(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice", + return Buffer(n->data, n->dtype, extents, strides, elem_offset[0], n->name + "_slice", n->data_alignment, 0, n->buffer_type); } @@ -407,15 +515,27 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, 
int content_lane
 Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
                PrimExpr elem_offset, String name, int data_alignment, int offset_factor,
-               BufferType buffer_type, Span span) {
+               BufferType buffer_type, Array<IntImm> axis_separators, Span span) {
   DataType storage_dtype = dtype;
   // specially handle bool
   if (storage_dtype == DataType::Bool()) {
     storage_dtype = DataType::Int(8);
   }
-  ICHECK(IsPointerType(data->type_annotation, storage_dtype))
-      << "Buffer data field expect to have the right pointer type annotation"
-      << " annotation=" << data->type_annotation << ", storage_dtype=" << storage_dtype;
+  // The buffer dtype may differ from the dtype of the underlying
+  // allocation, such as a single allocation that backs multiple
+  // tensors without a common datatype.  Therefore, we check that the
+  // data pointer is a pointer, but not the exact type of the
+  // pointed-to values.
+
+  // TODO(Lunderberg): Use an explicit pointer cast for the data
+  // pointer.  Should be done alongside extensions to StmtExprMutator
+  // to more easily handle buffer/buffer_var updates.
+  ICHECK(data->type_annotation.defined())
+      << "Variable " << data->name_hint << " is missing a type annotation.";
+  ICHECK(data->type_annotation.as<PointerTypeNode>())
+      << "Variable " << data->name_hint << " is not a pointer.";
+  ICHECK(data->type_annotation.as<PointerTypeNode>()->element_type.as<PrimTypeNode>())
+      << "Variable " << data->name_hint << " does not point to a primitive.";
   auto n = make_object<BufferNode>();
   n->data = std::move(data);
@@ -423,6 +543,7 @@ Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr>
   n->shape = std::move(shape);
   n->strides = std::move(strides);
+  n->axis_separators = std::move(axis_separators);
   n->name = std::move(name);
   if (!elem_offset.defined()) {
     elem_offset = make_const(n->DefaultIndexType(), 0);
@@ -455,15 +576,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 TVM_REGISTER_NODE_TYPE(BufferNode);
 
 TVM_REGISTER_GLOBAL("tir.Buffer").set_body([](TVMArgs args, TVMRetValue* ret) {
-  ICHECK_EQ(args.size(), 10);
+  ICHECK_EQ(args.size(), 11);
   auto buffer_type = args[8].operator String();
   BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault;
-  *ret =
-      Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], type, args[9]);
+  *ret = Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], type,
+                args[9], args[10]);
 });
 
 TVM_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr);
 
+TVM_REGISTER_GLOBAL("tir.BufferGetFlattenedBuffer").set_body_method(&Buffer::GetFlattenedBuffer);
+
+TVM_REGISTER_GLOBAL("tir.BufferOffsetOf").set_body_method(&Buffer::OffsetOf);
+
 TVM_REGISTER_GLOBAL("tir.BufferVLoad").set_body_method(&Buffer::vload);
 
 TVM_REGISTER_GLOBAL("tir.BufferVStore").set_body_method(&Buffer::vstore);
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index fbbd4a9522eb..ef533ef84b85 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -626,6 +626,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 // Load
 Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate, Span span) {
+  LOG(FATAL) << "Unexpected use of deprecated Load node for buffer " << buffer_var->name_hint
+             << ". Use BufferLoad instead.";
   ICHECK(buffer_var.defined());
   ICHECK(predicate.defined());
   ICHECK(index.defined());
@@ -1056,12 +1058,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<AnyNode>([](const ObjectRef& node, ReprPrinter* p) { p->stream << "?"; });
 
 // BufferLoad
+void BufferLoadNode::LegalizeDType() {
+  int index_lanes = 1;
+  for (const auto& index : indices) {
+    index_lanes *= index.dtype().lanes();
+  }
+
+  int buffer_lanes = buffer->dtype.lanes();
+
+  this->dtype = buffer->dtype.with_lanes(index_lanes * buffer_lanes);
+}
+
 BufferLoad::BufferLoad(Buffer buffer, Array<PrimExpr> indices, Span span) {
+  ICHECK_EQ(buffer->shape.size(), indices.size())
+      << "Buffer " << buffer->name << " is " << buffer->shape.size()
+      << "-dimensional, cannot be indexed with the " << indices.size()
+      << "-dimensional indices provided.";
+
   ObjectPtr<BufferLoadNode> node = make_object<BufferLoadNode>();
-  node->dtype = buffer->dtype;
   node->buffer = std::move(buffer);
   node->indices = std::move(indices);
   node->span = std::move(span);
+  node->LegalizeDType();
   data_ = std::move(node);
 }
diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc
index 4c5ea5bfd2d0..c8dc84695b4f 100644
--- a/src/tir/ir/expr_functor.cc
+++ b/src/tir/ir/expr_functor.cc
@@ -35,8 +35,7 @@ void ExprVisitor::VisitExpr_(const SizeVarNode* op) {
 void ExprVisitor::VisitExpr_(const AnyNode* op) {}
 
 void ExprVisitor::VisitExpr_(const LoadNode* op) {
-  this->VisitExpr(op->index);
-  this->VisitExpr(op->predicate);
+  LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
 }
 
 void ExprVisitor::VisitExpr_(const BufferLoadNode* op) {
@@ -127,13 +126,8 @@ PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) {
 
 PrimExpr ExprMutator::VisitExpr_(const AnyNode* op) { return GetRef<PrimExpr>(op); }
 
 PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) {
-  PrimExpr index = this->VisitExpr(op->index);
-  PrimExpr predicate = this->VisitExpr(op->predicate);
-  if (index.same_as(op->index) && predicate.same_as(op->predicate)) {
-    return GetRef<PrimExpr>(op);
-  } else {
-    return Load(op->dtype, op->buffer_var, index, predicate);
-  }
+  LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
+  return PrimExpr();
 }
 
 PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) {
diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc
index f58dd8aa820c..b9c3029d3c25 100644
--- a/src/tir/ir/function.cc
+++ b/src/tir/ir/function.cc
@@ -29,7 +29,9 @@ namespace tvm {
 namespace tir {
 // Get the function type of a PrimFunc
 PrimFunc::PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type,
-                   Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, Span span) {
+                   Map<tir::Var, Buffer> buffer_map,
+                   Optional<Map<tir::Var, Buffer>> preflattened_buffer_map, DictAttrs attrs,
+                   Span span) {
   // Assume void-return type for now
   // TODO(tvm-team) consider type deduction from body.
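+  // For illustration (editorial sketch; `params`, `body`, and
+  // `buffer_map` are assumed, as are the declaration's trailing
+  // defaults for attrs and span), callers that do not track
+  // pre-flattened shapes may pass NullOpt:
+  //
+  // \code
+  //   PrimFunc func(params, body, VoidType(), buffer_map, NullOpt);
+  // \endcode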
if (!ret_type.defined()) { @@ -40,6 +42,7 @@ PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, n->body = std::move(body); n->ret_type = std::move(ret_type); n->buffer_map = std::move(buffer_map); + n->preflattened_buffer_map = preflattened_buffer_map.value_or(Map()); n->attrs = std::move(attrs); n->checked_type_ = n->func_type_annotation(); n->span = std::move(span); @@ -118,8 +121,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_GLOBAL("tir.PrimFunc") .set_body_typed([](Array params, Stmt body, Type ret_type, - Map buffer_map, DictAttrs attrs, Span span) { - return PrimFunc(params, body, ret_type, buffer_map, attrs, span); + Map buffer_map, + Map preflattened_buffer_map, DictAttrs attrs, Span span) { + return PrimFunc(params, body, ret_type, buffer_map, preflattened_buffer_map, attrs, span); }); TVM_REGISTER_GLOBAL("tir.TensorIntrin") diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc new file mode 100644 index 000000000000..ba0998e84ffc --- /dev/null +++ b/src/tir/ir/index_map.cc @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file index_map.cc + */ + +#include "tvm/tir/index_map.h" + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace tir { + +IndexMap::IndexMap(Array initial_indices, Array final_indices) { + auto n = make_object(); + n->initial_indices = std::move(initial_indices); + n->final_indices = std::move(final_indices); + data_ = std::move(n); +} + +IndexMap IndexMap::Inverse(Array initial_ranges) const { + // Dummy variables to represent the inverse's inputs. + Array output_vars; + for (size_t i = 0; i < (*this)->final_indices.size(); i++) { + PrimExpr index = (*this)->final_indices[i]; + // TODO(Lunderberg): Better names for these variables. A variable + // that is passed through unmodified (`index` is an element of + // `initial_indices`) should use that input index's name. A pair + // of output indices variables split from a single input index + // should be named (X.outer,X.inner). + std::stringstream ss; + ss << "axis" << i; + Var var_index(ss.str(), index.dtype()); + output_vars.push_back(var_index); + } + + // Dummy ranges for the extent of each input. + Map input_iters; + ICHECK_EQ((*this)->initial_indices.size(), initial_ranges.size()); + for (size_t i = 0; i < initial_ranges.size(); i++) { + input_iters.Set((*this)->initial_indices[i], initial_ranges[i]); + } + + // Unpack the output indices into linear combinations of the initial + // indices. 
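+  // Worked example (editorial): for the map (i, j) -> (i/4, j, i%4)
+  // with i in [0, 16), the detection below proves the mapping is
+  // bijective, and InverseAffineIterMap then recovers
+  // i = 4*axis0 + axis2 and j = axis1 in terms of the output
+  // variables constructed above.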
+ arith::Analyzer analyzer; + auto diagnostics = DiagnosticContext::Default(IRModule()); + auto iter_map = + DetectIterMap((*this)->final_indices, input_iters, 1, true, &analyzer, diagnostics); + CHECK(iter_map.size()) << "Index transformation was not bijective."; + + // Determine expressions for the input variables, in terms of the + // output variables. + Map inverse_exprs_map = + InverseAffineIterMap(iter_map, Array(output_vars.begin(), output_vars.end())); + + // Unpack the map to an array, maintaining the same parameter order. + Array inverse_exprs; + for (const auto& index : (*this)->initial_indices) { + inverse_exprs.push_back(inverse_exprs_map.at(index)); + } + + return IndexMap(output_vars, inverse_exprs); +} + +Array IndexMapNode::MapIndices(const Array& indices) const { + ICHECK_EQ(indices.size(), initial_indices.size()); + + arith::Analyzer analyzer; + + for (size_t i = 0; i < initial_indices.size(); i++) { + analyzer.Bind(initial_indices[i], indices[i]); + } + + Array output; + for (const auto& output_dim : final_indices) { + output.push_back(analyzer.Simplify(output_dim)); + } + + return output; +} + +Array IndexMapNode::MapRanges(const Array& ranges) const { + ICHECK_EQ(ranges.size(), initial_indices.size()); + + Map input_iters; + for (size_t i = 0; i < initial_indices.size(); i++) { + input_iters.Set(initial_indices[i], ranges[i]); + } + + std::unordered_map dom_map; + for (size_t i = 0; i < initial_indices.size(); i++) { + dom_map[initial_indices[i].get()] = arith::IntSet::FromRange(ranges[i]); + } + + Array output; + for (const auto& final_index : final_indices) { + auto int_set = arith::EvalSet(final_index, dom_map); + output.push_back(Range::FromMinExtent(int_set.min(), int_set.max() - int_set.min() + 1)); + } + + return output; +} + +Array IndexMapNode::MapShape(const Array& shape) const { + ICHECK_EQ(shape.size(), initial_indices.size()); + + Array ranges; + for (auto& dim : shape) { + ranges.push_back(Range(0, dim)); + } + Array mapped = MapRanges(std::move(ranges)); + + Array output; + for (auto& range : mapped) { + ICHECK(is_zero(range->min)); + output.push_back(range->extent); + } + + return output; +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "index_map(" << op->initial_indices << ", " << op->final_indices << ")"; + }); + +TVM_REGISTER_NODE_TYPE(IndexMapNode); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 1269607fd334..3914f41e4f34 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -241,6 +241,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Store Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate, Span span) { + LOG(FATAL) << "Unexpected use of deprecated Store node for buffer " << buffer_var->name_hint + << ". 
Use BufferStore instead."; ICHECK(value.defined()); ICHECK(index.defined()); ICHECK(predicate.defined()); @@ -341,7 +343,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Allocate Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, Stmt body, Map annotations, Span span) { - CHECK(IsPointerType(buffer_var->type_annotation, dtype)) + CHECK(IsPointerType(buffer_var->type_annotation, dtype) || + (dtype.is_bool() && IsPointerType(buffer_var->type_annotation, DataType::Int(8)))) << "The allocated data type (" << dtype << ") does not match the type annotation of the buffer " << buffer_var << " (" << buffer_var->type_annotation @@ -668,6 +671,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // BufferStore BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, Span span) { + ICHECK_EQ(buffer->shape.size(), indices.size()) + << "Buffer " << buffer->name << " is " << buffer->shape.size() + << "-dimensional, cannot be indexed with the " << indices.size() + << "-dimensional indices provided."; + ObjectPtr node = make_object(); node->buffer = std::move(buffer); node->value = std::move(value); @@ -760,7 +768,12 @@ BufferRegion BufferRegion::FullRegion(Buffer buffer) { BufferRegion BufferRegion::FromPoint(Buffer buffer, Array indices) { Array region; for (const PrimExpr& index : indices) { - region.push_back(Range::FromMinExtent(index, 1)); + if (const RampNode* ramp_index = index.as()) { + region.push_back( + Range::FromMinExtent(ramp_index->base, ramp_index->stride * ramp_index->lanes)); + } else { + region.push_back(Range::FromMinExtent(index, 1)); + } } return BufferRegion(buffer, region); } diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 949e8a1312aa..c4d7ad0f6c67 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -64,9 +64,7 @@ void StmtVisitor::VisitStmt_(const AllocateConstNode* op) { } void StmtVisitor::VisitStmt_(const StoreNode* op) { - this->VisitExpr(op->value); - this->VisitExpr(op->index); - this->VisitExpr(op->predicate); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; } void StmtVisitor::VisitStmt_(const BufferStoreNode* op) { @@ -358,18 +356,8 @@ Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) { } Stmt StmtMutator::VisitStmt_(const StoreNode* op) { - PrimExpr value = this->VisitExpr(op->value); - PrimExpr index = this->VisitExpr(op->index); - PrimExpr predicate = this->VisitExpr(op->predicate); - if (value.same_as(op->value) && index.same_as(op->index) && predicate.same_as(op->predicate)) { - return GetRef(op); - } else { - auto n = CopyOnWrite(op); - n->value = std::move(value); - n->index = std::move(index); - n->predicate = std::move(predicate); - return Stmt(n); - } + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); } Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) { @@ -664,23 +652,51 @@ class IRSubstitute : public StmtExprMutator { } PrimExpr VisitExpr_(const LoadNode* op) final { - PrimExpr ret = StmtExprMutator::VisitExpr_(op); - op = ret.as(); - if (auto mapped_var = vmap_(op->buffer_var)) { - return Load(op->dtype, Downcast(mapped_var.value()), op->index, op->predicate); - } else { - return ret; - } + LOG(FATAL) << "Unexpected use of deprecated LoadNode. 
Please use BufferLoadNode instead."; + return PrimExpr(); } Stmt VisitStmt_(const StoreNode* op) final { - Stmt ret = StmtExprMutator::VisitStmt_(op); - op = ret.as(); - if (auto mapped_var = vmap_(op->buffer_var)) { - return Store(Downcast(mapped_var.value()), op->value, op->index, op->predicate); - } else { - return ret; + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + template + Node VisitBufferAccess(Node node) { + Buffer new_buf = GetRemappedBuffer(node->buffer); + + if (!new_buf.same_as(node->buffer)) { + auto writer = node.CopyOnWrite(); + writer->buffer = new_buf; } + + return node; + } + + Buffer GetRemappedBuffer(Buffer buf) { + auto key = buf.get(); + auto it = buf_remap_.find(key); + if (it != buf_remap_.end()) { + return it->second; + } + + if (auto mapped_var = vmap_(buf->data)) { + auto writer = buf.CopyOnWrite(); + writer->data = Downcast(mapped_var); + } + + buf_remap_[key] = buf; + return buf; } Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -696,7 +712,17 @@ class IRSubstitute : public StmtExprMutator { } private: + // Caller provided function that defines the variables to be remapped. std::function(const Var&)> vmap_; + + /* \brief Generated map to track buffers being remapped. + * + * If a `Var BufferNode::data` is remapped, then all buffers + * containing that data pointer should also be remapped. This map + * is used to track buffer modifications, and ensure all instances + * of a buffer are replaced by the same modified buffer object. + */ + std::unordered_map buf_remap_; }; Stmt Substitute(Stmt stmt, std::function(const Var&)> vmap) { diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 6231bb229bf9..ed3ececcebfb 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -456,13 +456,9 @@ class CacheReadRewriter : public StmtExprMutator { return ExprMutator::VisitExpr_(load); } - PrimExpr VisitExpr_(const LoadNode* load) final { - if (load->buffer_var.same_as(info_->read_buffer->data)) { - ObjectPtr n = make_object(*load); - n->buffer_var = info_->write_buffer->data; - return PrimExpr(n); - } - return ExprMutator::VisitExpr_(load); + PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); } PrimExpr VisitExpr_(const VarNode* op) final { @@ -575,22 +571,14 @@ class CacheWriteRewriter : public StmtExprMutator { return ExprMutator::VisitExpr_(load); } - PrimExpr VisitExpr_(const LoadNode* load) final { - if (load->buffer_var.same_as(info_->write_buffer->data)) { - ObjectPtr n = make_object(*load); - n->buffer_var = info_->read_buffer->data; - return PrimExpr(n); - } - return ExprMutator::VisitExpr_(load); + PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. 
Please use BufferLoadNode instead."; + return PrimExpr(); } - Stmt VisitStmt_(const StoreNode* store) final { - if (store->buffer_var.same_as(info_->write_buffer->data)) { - ObjectPtr n = make_object(*store); - n->buffer_var = info_->read_buffer->data; - return Stmt(n); - } - return StmtMutator::VisitStmt_(store); + Stmt VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); } PrimExpr VisitExpr_(const VarNode* op) final { diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 9a9860b42bc6..d7556ed73995 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -200,14 +200,14 @@ class BaseInliner : public StmtExprMutator { return StmtExprMutator::VisitExpr_(var); } - PrimExpr VisitExpr_(const LoadNode* load) final { - CheckOpaqueAccess(load->buffer_var.get()); - return StmtExprMutator::VisitExpr_(load); + PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); } - Stmt VisitStmt_(const StoreNode* store) final { - CheckOpaqueAccess(store->buffer_var.get()); - return StmtExprMutator::VisitStmt_(store); + Stmt VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); } Stmt VisitStmt_(const ForNode* loop) final { @@ -284,6 +284,31 @@ class BaseInliner : public StmtExprMutator { } } + /*! + * \brief Count the number of undefined variables that are not used + * as buffer objects. + * + * This is used to determine whether inlining or reverse inlining is + * possible. The only undefined variables present should be the + * load/store indices, or buffer access based on those indices. + * + * \param stmt The statement in which to count undefined variables + */ + static int GetNumUndefinedNonpointerVars(const Stmt& stmt) { + auto undefined_vars = UndefinedVars(stmt, {}); + // Buffer pointers and the inlined indices are allowed, but no + // other variables may appear in the inlined block. + int num_nonpointer_vars = 0; + for (const auto& var : undefined_vars) { + bool is_pointer = var->dtype.is_handle() && var->type_annotation.defined() && + var->type_annotation.as(); + if (!is_pointer) { + num_nonpointer_vars++; + } + } + return num_nonpointer_vars; + } + private: /*! 
* \brief Add the buffers in the block signature to the `buffer_var_map_`, @@ -417,7 +442,8 @@ class ComputeInliner : public BaseInliner { if (inlined_store_ == nullptr) { return false; } - int n_vars = UndefinedVars(GetRef(inlined_store_), {}).size(); + + int n_vars = GetNumUndefinedNonpointerVars(GetRef(inlined_store_)); if (!UpdateAndCheckIndexVars(inlined_store_->indices, n_vars)) { return false; } @@ -484,7 +510,7 @@ class ReverseComputeInliner : public BaseInliner { // Failure: no BufferLoad from the `inlined_buffer_` return false; } - int n_vars = UndefinedVars(GetRef(inlined_store_), {}).size(); + int n_vars = GetNumUndefinedNonpointerVars(GetRef(inlined_store_)); for (const BufferLoadNode* load : loads) { if (!UpdateAndCheckIndexVars(load->indices, n_vars)) { // Failure: incorrect or inconsistent index vars diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index 1e566a980463..d7cd731a3d2b 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -154,23 +154,34 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, const Stmt nop = Evaluate(0); // dimension checks PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim); + + // Helper functions for shape/stride name formatting + auto shape_handle_name = [&]() { return arg_name + ".shape"; }; + auto stride_handle_name = [&]() { return arg_name + ".strides"; }; + auto array_element_name = [&](const std::string& arr_name, size_t k) { + std::stringstream ss; + ss << arr_name << '[' << k << ']'; + return ss.str(); + }; + auto shape_element_name = [&](size_t k) { return array_element_name(shape_handle_name(), k); }; + auto stride_element_name = [&](size_t k) { return array_element_name(stride_handle_name(), k); }; + PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast(buffer->shape.size())); std::ostringstream ndim_err_msg; ndim_err_msg << arg_name << ".ndim is expected to equal " << buffer->shape.size(); auto msg = tvm::tir::StringImm(ndim_err_msg.str()); asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop)); // type checks - DataType dtype = buffer->dtype; std::ostringstream type_err_msg; - type_err_msg << arg_name << ".dtype is expected to be " << dtype; + type_err_msg << arg_name << ".dtype is expected to be " << buffer->dtype; PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeCode) == - IntImm(DataType::UInt(8), dtype.code()) && + IntImm(DataType::UInt(8), buffer->dtype.code()) && TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits) == - IntImm(DataType::UInt(8), dtype.bits()) && + IntImm(DataType::UInt(8), buffer->dtype.bits()) && TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes) == - IntImm(DataType::UInt(16), dtype.lanes())); - if (!(dtype == DataType::Int(1) || dtype == DataType::Int(4) || dtype == DataType::UInt(4) || - dtype == DataType::UInt(16))) { + IntImm(DataType::UInt(16), buffer->dtype.lanes())); + if (!(buffer->dtype == DataType::Int(1) || buffer->dtype == DataType::Int(4) || + buffer->dtype == DataType::UInt(4) || buffer->dtype == DataType::UInt(16))) { auto type_msg = tvm::tir::StringImm(type_err_msg.str()); asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop)); asserts_.emplace_back(AssertStmt(cond, type_msg, nop)); @@ -185,27 +196,29 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, IntImm(DataType::Int(32), buffer->data_alignment), nop)); } - Var v_shape(arg_name + ".shape", DataType::Handle()); + // shape field +
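For reference, the metadata-access pattern introduced in the hunk below: instead of reading a raw handle Var through the deprecated Load node, the handle is wrapped in a synthetic flat Buffer via decl_buffer, and its elements are read with BufferLoad. A minimal sketch under assumed names (the buffer name "my_meta", the int64 element type, and the literal sizes are illustrative, not part of the patch):

    // Hypothetical: a 4-element int64 metadata array.
    Buffer meta = decl_buffer({IntImm(DataType::Int(32), 4)}, DataType::Int(64), "my_meta");
    // After binding meta->data to the runtime pointer (e.g. with a LetStmt),
    // element 2 is read through the buffer instead of a raw Load node:
    PrimExpr elem = BufferLoad(meta, {IntImm(DataType::Int(32), 2)});
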
Buffer buf_shape = decl_buffer({IntImm(DataType::Int(32), buffer->shape.size())}, tvm_shape_type, + shape_handle_name()); + Var v_shape(shape_handle_name(), DataType::Handle()); def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0)); init_nest_.emplace_back( - LetStmt(v_shape, TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), nop)); + LetStmt(buf_shape->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), nop)); for (size_t k = 0; k < buffer->shape.size(); ++k) { - if (dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1)) { + if (buffer->dtype == DataType::Int(4) || buffer->dtype == DataType::UInt(4) || + buffer->dtype == DataType::Int(1)) { break; } - std::ostringstream field_name; - field_name << v_shape->name_hint << '[' << k << ']'; Bind_(buffer->shape[k], - cast(buffer->shape[k].dtype(), - Load(tvm_shape_type, v_shape, IntImm(DataType::Int(32), k), const_true(1))), - field_name.str(), true); + cast(buffer->shape[k].dtype(), BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})), + shape_element_name(k), true); } // strides field - Var v_strides(arg_name + ".strides", DataType::Handle()); - def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type)); - init_nest_.emplace_back( - LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop)); - PrimExpr v_strides_is_null = Call(DataType::Bool(1), builtin::isnullptr(), {v_strides}); + Buffer buf_strides = decl_buffer({IntImm(DataType::Int(32), buffer->strides.size())}, + tvm_shape_type, arg_name + ".strides"); + def_handle_dtype_.Set(buf_strides->data, tir::TypeAnnotation(tvm_shape_type)); + init_nest_.emplace_back(LetStmt( + buf_strides->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop)); + PrimExpr v_strides_is_null = Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data}); if (buffer->strides.size() == 0) { // Assert the buffer is compact DataType stype = buffer->DefaultIndexType(); @@ -213,14 +226,12 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, Array conds; for (size_t i = buffer->shape.size(); i != 0; --i) { size_t k = i - 1; - PrimExpr svalue = - cast(stype, Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1))); + PrimExpr svalue = cast(stype, BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); conds.push_back(expect_stride == svalue); expect_stride = expect_stride * buffer->shape[k]; } std::ostringstream stride_err_msg; - stride_err_msg << arg_name << ".strides:" - << " expected to be compact array"; + stride_err_msg << stride_handle_name() << ": expected to be compact array"; if (conds.size() != 0) { auto stride_msg = tvm::tir::StringImm(stride_err_msg.str()); Stmt check = AssertStmt( @@ -235,34 +246,26 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, PrimExpr stride = make_const(stype, 1); for (size_t i = buffer->shape.size(); i != 0; --i) { size_t k = i - 1; - std::ostringstream field_name; - field_name << v_strides->name_hint << '[' << k << ']'; PrimExpr value = - cast(buffer->shape[k].dtype(), - Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1))); + cast(buffer->shape[k].dtype(), BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); value = tvm::if_then_else(v_strides_is_null, stride, value); value = tvm::if_then_else(buffer->shape[k] == 1, 0, value); - Bind_(buffer->strides[k], value, field_name.str(), true); + Bind_(buffer->strides[k], value, 
stride_element_name(k), true); stride = analyzer_.Simplify(stride * buffer->shape[k]); } } else { PrimExpr stride_from_shape = 1; for (int k = buffer->strides.size() - 1; k >= 0; k--) { - std::ostringstream field_name; - field_name << v_strides->name_hint << '[' << k << ']'; - PrimExpr explicit_stride = - cast(buffer->shape[k].dtype(), - Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1))); + cast(buffer->shape[k].dtype(), BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); Bind_(buffer->strides[k], tvm::if_then_else(v_strides_is_null, stride_from_shape, explicit_stride), - field_name.str(), true); + stride_element_name(k), true); stride_from_shape *= - cast(buffer->shape[k].dtype(), - Load(tvm_shape_type, v_shape, IntImm(DataType::Int(32), k), const_true(1))); + cast(buffer->shape[k].dtype(), BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})); } } // Byte_offset field. diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 79c406818185..193584f84b47 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -199,11 +199,11 @@ class BF16LowerRewriter : public StmtExprMutator { Stmt ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); - auto it = buffer_remap_.find(op->buffer); - if (it != buffer_remap_.end()) { - return BufferStore(it->second, op->value, op->indices); - } else { + Buffer new_buf = GetRemappedBuffer(op->buffer); + if (new_buf.same_as(op->buffer)) { return ret; + } else { + return BufferStore(new_buf, op->value, op->indices); } } @@ -229,50 +229,34 @@ class BF16LowerRewriter : public StmtExprMutator { Stmt ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); - auto it = buffer_remap_.find(op->buffer); - if (it != buffer_remap_.end()) { - return BufferRealize(it->second, op->bounds, op->condition, op->body); - } else { + Buffer new_buf = GetRemappedBuffer(op->buffer); + if (new_buf.same_as(op->buffer)) { return ret; + } else { + return BufferRealize(new_buf, op->bounds, op->condition, op->body); } } Stmt VisitStmt_(const StoreNode* op) final { - // NOTE: we do not explicit recursivly mutate op->buffer_var - Stmt ret = StmtExprMutator::VisitStmt_(op); - op = ret.as(); - - auto it = var_remap_.find(op->buffer_var); - if (it != var_remap_.end()) { - return Store(it->second, op->value, op->index, op->predicate); - } else { - return ret; - } + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); } PrimExpr VisitExpr_(const BufferLoadNode* op) final { PrimExpr ret = StmtExprMutator::VisitExpr_(op); op = ret.as(); - auto it = buffer_remap_.find(op->buffer); - if (it != buffer_remap_.end()) { - return BufferLoad(it->second, op->indices); - } else { + Buffer new_buf = GetRemappedBuffer(op->buffer); + if (new_buf.same_as(op->buffer)) { return ret; + } else { + return BufferLoad(new_buf, op->indices); } } PrimExpr VisitExpr_(const LoadNode* op) final { - PrimExpr ret = StmtExprMutator::VisitExpr_(op); - op = ret.as(); - - if (op->dtype.is_bfloat16()) { - auto it = var_remap_.find(op->buffer_var); - ICHECK(it != var_remap_.end()) << "bfloat* var needs to be remapped"; - return Load(DataType::UInt(16, op->dtype.lanes()), it->second, op->index, op->predicate); - } else { - return ret; - } + LOG(FATAL) << "Unexpected use of deprecated LoadNode. 
Please use BufferLoadNode instead."; + return PrimExpr(); } PrimExpr VisitExpr_(const FloatImmNode* op) final { @@ -284,9 +268,10 @@ class BF16LowerRewriter : public StmtExprMutator { } void AlterBuffers(PrimFuncNode* op) { - std::vector> changes; + Map new_buffer_map; for (auto& itr : op->buffer_map) { + auto param_var = itr.first; auto oldbuf = itr.second; if (oldbuf->dtype.is_bfloat16()) { DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes()); @@ -296,18 +281,69 @@ class BF16LowerRewriter : public StmtExprMutator { oldbuf->buffer_type); buffer_remap_[oldbuf] = newbuf; var_remap_[oldbuf->data] = buffer_var; - changes.emplace_back(itr.first, newbuf); + new_buffer_map.Set(param_var, newbuf); } else { - changes.emplace_back(itr); + new_buffer_map.Set(param_var, oldbuf); + } + } + + // Most passes do not change the preflattened buffer map, nor + // should they change it. This is an exception, because the Var + // associated with the `BufferNode::data` in + // `PrimFunc::buffer_map` may be replaced, and the corresponding + // Var in the `PrimFunc::preflattened_buffer_map` must also be + // replaced. + Map new_preflattened_buffer_map; + for (auto& itr : op->preflattened_buffer_map) { + auto param_var = itr.first; + auto oldbuf = itr.second; + if (oldbuf->dtype.is_bfloat16()) { + auto it = new_buffer_map.find(param_var); + ICHECK(it != new_buffer_map.end()) + << "PrimFunc parameter " << param_var->name_hint + << " is associated with the pre-flattened buffer " << oldbuf->name + << ", but isn't associated with any post-flatten buffer."; + const Buffer& flatbuf = (*it).second; + DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes()); + auto newbuf = Buffer(flatbuf->data, dtype, oldbuf->shape, oldbuf->strides, + oldbuf->elem_offset, oldbuf->name, oldbuf->data_alignment, + oldbuf->offset_factor, oldbuf->buffer_type); + buffer_remap_[oldbuf] = newbuf; + new_preflattened_buffer_map.Set(param_var, newbuf); + } else { + new_preflattened_buffer_map.Set(param_var, oldbuf); } } if (buffer_remap_.size() != 0) { - op->buffer_map = Map(changes.begin(), changes.end()); + op->buffer_map = new_buffer_map; + op->preflattened_buffer_map = new_preflattened_buffer_map; } } private: + Buffer GetRemappedBuffer(Buffer buf) { + auto buf_it = buffer_remap_.find(buf); + if (buf_it != buffer_remap_.end()) { + return buf_it->second; + } + + Buffer new_buf = buf; + + auto var_it = var_remap_.find(buf->data); + if (var_it != var_remap_.end()) { + DataType dtype = + buf->dtype.is_bfloat16() ? 
DataType::UInt(16, buf->dtype.lanes()) : buf->dtype; + new_buf = Buffer(var_it->second, dtype, buf->shape, buf->strides, buf->elem_offset, buf->name, + buf->data_alignment, buf->offset_factor, buf->buffer_type, + buf->axis_separators, buf->span); + } + + buffer_remap_[buf] = new_buf; + + return new_buf; + } + std::unordered_map buffer_remap_; std::unordered_map var_remap_; }; diff --git a/src/tir/transforms/bind_params.cc b/src/tir/transforms/bind_params.cc index 944a67a879fd..1d2b2db207bd 100644 --- a/src/tir/transforms/bind_params.cc +++ b/src/tir/transforms/bind_params.cc @@ -53,12 +53,11 @@ class ParamsCollector : public StmtExprVisitor { return constant_list_; } - void VisitExpr_(const LoadNode* ln) { - if (constant_map_.find(ln->buffer_var) != constant_map_.end()) { - auto it = - std::find(constant_list_.begin(), constant_list_.end(), ln->buffer_var.operator->()); + void VisitExpr_(const BufferLoadNode* ln) { + if (constant_map_.find(ln->buffer->data) != constant_map_.end()) { + auto it = std::find(constant_list_.begin(), constant_list_.end(), ln->buffer->data.get()); if (it == constant_list_.end()) { - constant_list_.push_back(ln->buffer_var.operator->()); + constant_list_.push_back(ln->buffer->data.get()); } } StmtExprVisitor::VisitExpr_(ln); diff --git a/src/tir/transforms/bound_checker.cc b/src/tir/transforms/bound_checker.cc index 3b6af0644fc9..85aac3cee855 100644 --- a/src/tir/transforms/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -37,25 +37,30 @@ namespace tvm { namespace tir { +// TODO(Lunderberg): Move this pass to be before +// StorageFlatten/FlattenBuffer. That will simplify this pass, +// because it can check directly against the buffer limits. class BoundCollector : public StmtVisitor { public: BoundCollector() {} void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == tir::attr::buffer_bound) { - if (const VarNode* key = op->node.as()) { - mem_to_shape[key] = op->value; + const VarNode* key = op->node.as(); + const CallNode* container = op->value.as(); + if (key && container) { + mem_to_shape[key] = container->args; } } StmtVisitor::VisitStmt_(op); } // Hashtable which maps buffer_var to shape. - std::unordered_map mem_to_shape; + std::unordered_map> mem_to_shape; }; class BoundChecker : public StmtExprMutator { public: - explicit BoundChecker(const std::unordered_map& mem_to_shape) + explicit BoundChecker(const std::unordered_map>& mem_to_shape) : mem_to_shape_(mem_to_shape) {} Stmt VisitStmt_(const AllocateNode* op) final { @@ -73,21 +78,31 @@ class BoundChecker : public StmtExprMutator { return StmtExprMutator::VisitExpr_(op); } + PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); + } + Stmt VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { store_scope_bound_collector_.clear(); process_store_ = true; unsafe_rewritten_ = false; StmtExprMutator::VisitStmt_(op); process_store_ = false; - if (CanInstrument(op->index, op->buffer_var)) { - Collect(op->index, op->buffer_var); + if (CanInstrument(op->indices, op->buffer->data)) { + Collect(op->indices, op->buffer->data); } // The collector should have at least one item.
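When instrumentation applies, the construction that follows wraps the original store in a guard of this shape (a sketch, not a verbatim excerpt; `store` and `condition` stand for the BufferStore node and the predicate built in the surrounding code):

    // In-bounds: run the original store. Out-of-bounds: trap with the
    // pass's error message.
    Stmt nop = Evaluate(1);
    Stmt then_case = store;  // the original BufferStore
    Stmt else_case = AssertStmt(condition, StringImm("OUT OF THE BOUNDS"), nop);
    Stmt guarded = IfThenElse(condition, then_case, else_case);
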
if (store_scope_bound_collector_.size()) { PrimExpr condition = MakeCondition(); if (!condition.as()) { Stmt nop = Evaluate(1); - Stmt then_case = Store(op->buffer_var, op->value, op->index, op->predicate); + Stmt then_case = GetRef(op); Stmt else_case = AssertStmt(condition, StringImm(error_message_), nop); Stmt body = IfThenElse(condition, then_case, else_case); return body; @@ -96,9 +111,9 @@ class BoundChecker : public StmtExprMutator { return GetRef(op); } - PrimExpr VisitExpr_(const LoadNode* op) final { - if (CanInstrument(op->index, op->buffer_var)) { - Collect(op->index, op->buffer_var); + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + if (CanInstrument(op->indices, op->buffer->data)) { + Collect(op->indices, op->buffer->data); } return StmtExprMutator::VisitExpr_(op); } @@ -108,79 +123,106 @@ class BoundChecker : public StmtExprMutator { return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get())); } - void Update(const Var& buffer_var, const Array& new_shape, const DataType& type) { + void Update(const Var& buffer_var, Array new_shape, const DataType& type) { // Sanity check at first. - if (!new_shape.size()) { + if (!ShapeIsValid(new_shape)) { return; } - for (size_t i = 0; i < new_shape.size(); ++i) { - if (!new_shape[0].defined() || !new_shape[i].dtype().is_scalar() || - is_negative_const(new_shape[i])) { - return; + new_shape.MutateByApply([&](const PrimExpr& dim) { + // Cast to uint64 to avoid potential overflow. + return make_const(DataType::UInt(64), type.lanes()) * dim; + }); + mem_to_shape_[buffer_var.get()] = new_shape; + } + + bool ShapeIsValid(const Array& shape) const { + if (!shape.defined()) { + return false; + } + for (const auto& dim : shape) { + if (!IsValidScalar(dim) || is_negative_const(dim)) { + return false; } } - // Scalarize the shape. - PrimExpr shape = - Mul(make_const(DataType::UInt(64), type.lanes()), Cast(DataType::UInt(64), new_shape[0])); - for (size_t i = 1; i < new_shape.size(); ++i) { - // Cast to unsigned to avoid integer overlow at frist. 
- shape = Mul(shape, Mul(make_const(DataType::UInt(64), type.lanes()), - Cast(DataType::UInt(64), new_shape[i]))); - } - mem_to_shape_[buffer_var.get()] = shape; + return true; } - bool IndexIsValid(const PrimExpr& index) const { - if (!index.defined()) { + bool IndicesAreValid(const Array& indices) const { + if (!indices.defined()) { return false; } - if (const RampNode* ramp_index = index.as()) { - return ramp_index->base.defined() && ramp_index->base.dtype().is_scalar() && - ramp_index->stride.defined() && ramp_index->stride.dtype().is_scalar() && - (ramp_index->lanes > 0); + for (const auto& index : indices) { + if (!index.defined()) { + return false; + } + + if (const RampNode* ramp_index = index.as()) { + if (!IsValidScalar(ramp_index->base)) { + return false; + } + if (!IsValidScalar(ramp_index->stride)) { + return false; + } + if (ramp_index->lanes <= 0) { + return false; + } + } } return true; } - bool CanInstrument(const PrimExpr& index, const Var& buffer_var) const { - return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) && IndexIsValid(index) && - !unsafe_rewritten_; + bool IsValidScalar(const PrimExpr& expr) const { + return expr.defined() && expr.dtype().is_scalar(); + } + + bool CanInstrument(const Array& indices, const Var& buffer_var) const { + return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) && + IndicesAreValid(indices) && !unsafe_rewritten_; } - void Collect(PrimExpr index, Var buffer_var) { - store_scope_bound_collector_.push_back(std::make_pair(index, mem_to_shape_[buffer_var.get()])); + void Collect(Array indices, Var buffer_var) { + store_scope_bound_collector_.push_back( + std::make_pair(indices, mem_to_shape_[buffer_var.get()])); } PrimExpr MakeCondition() { PrimExpr condition; - for (size_t i = 0; i < store_scope_bound_collector_.size(); ++i) { - std::pair buffer_to_mem = store_scope_bound_collector_[i]; - PrimExpr index = buffer_to_mem.first; - PrimExpr upper_bound = buffer_to_mem.second; - - if (const RampNode* ramp_index = index.as()) { - // In case index is base + stride * i. - // Non inclusive range. - index = Add(ramp_index->base, Mul(ramp_index->stride, make_const(ramp_index->stride.dtype(), - ramp_index->lanes - 1))); + for (const auto& pair : store_scope_bound_collector_) { + Array indices = pair.first; + Array shape = pair.second; + + ICHECK_EQ(indices.size(), shape.size()) + << "Mismatch between dimension of physical shape and physical indices"; + + for (size_t i = 0; i < indices.size(); i++) { + PrimExpr index = indices[i]; + PrimExpr upper_bound = shape[i]; + + if (const RampNode* ramp_index = index.as()) { + // In case index is base + stride * i. + // Non inclusive range. + index = Add(ramp_index->base, + Mul(ramp_index->stride, + make_const(ramp_index->stride.dtype(), ramp_index->lanes - 1))); + } + + // Try to simplify index and bound. + index = analyzer_.Simplify(index); + upper_bound = analyzer_.Simplify(upper_bound); + + // Cast to the same type - signed, to be able to check lower bound. + index = Cast(DataType::Int(64), index); + upper_bound = Cast(DataType::Int(64), upper_bound); + + // Looks like a lower bound should always be zero after normalization. + PrimExpr lower_bound = make_zero(DataType::Int(64)); + + PrimExpr current_condition = And(GE(index, lower_bound), LT(index, upper_bound)); + condition = condition.defined() ? And(condition, current_condition) : current_condition; } - - // Try to simplify index and bound. 
- index = analyzer_.Simplify(index); - upper_bound = analyzer_.Simplify(upper_bound); - - // Cast to the same type - signed, to be able to check lower bound. - index = Cast(DataType::Int(64), index); - upper_bound = Cast(DataType::Int(64), upper_bound); - - // Looks like a lower bound should always be zero after normalization. - PrimExpr lower_bound = make_zero(DataType::Int(64)); - - PrimExpr current_condition = And(GE(index, lower_bound), LT(index, upper_bound)); - condition = !i ? current_condition : And(condition, current_condition); } return condition; } @@ -190,11 +232,11 @@ class BoundChecker : public StmtExprMutator { // Whether we face tvm_if_then_else intrinsic. bool unsafe_rewritten_{false}; // Pool which collects the pair of index and shape for specific store/load. - std::vector> store_scope_bound_collector_; + std::vector, Array>> store_scope_bound_collector_; // Error message. const char* const error_message_ = "OUT OF THE BOUNDS"; // Hashtable which maps buffer_var to shape. - std::unordered_map mem_to_shape_; + std::unordered_map> mem_to_shape_; // internal analyzer arith::Analyzer analyzer_; }; diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index 20ddd7f84a35..6a317397d6ea 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -99,13 +99,11 @@ class BufferAccessRegionCollector : public StmtExprVisitor { void VisitExpr_(const VarNode* op) final { VisitBufferVar(GetRef(op)); } void VisitExpr_(const LoadNode* op) final { - StmtExprVisitor::VisitExpr_(op); - VisitBufferVar(op->buffer_var); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; } void VisitStmt_(const StoreNode* op) final { - StmtExprVisitor::VisitStmt_(op); - VisitBufferVar(op->buffer_var); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; } void VisitStmt_(const ForNode* op) final { @@ -217,7 +215,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor { continue; } auto dom_it = dom_map_.find(v); - ICHECK(dom_it != dom_map_.end()); + ICHECK(dom_it != dom_map_.end()) + << "Could not find domain for loop variable " << v->name_hint; non_relaxed[i] = dom_it->second; dom_map_.erase(dom_it); } diff --git a/src/tir/transforms/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc index 7a6d2d37c376..f3a9f990599f 100644 --- a/src/tir/transforms/coproc_sync.cc +++ b/src/tir/transforms/coproc_sync.cc @@ -39,18 +39,24 @@ namespace tir { class CoProcTouchedBuffer : public StmtExprVisitor { public: void VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + } + void VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. 
Please use BufferStoreNode instead."; + } + void VisitExpr_(const BufferLoadNode* op) final { if (in_scope_) { - touched_[op->buffer_var.get()].coproc = true; + touched_[op->buffer->data.get()].coproc = true; } else { - touched_[op->buffer_var.get()].normal = true; + touched_[op->buffer->data.get()].normal = true; } StmtExprVisitor::VisitExpr_(op); } - void VisitStmt_(const StoreNode* op) final { + void VisitStmt_(const BufferStoreNode* op) final { if (in_scope_) { - touched_[op->buffer_var.get()].coproc = true; + touched_[op->buffer->data.get()].coproc = true; } else { - touched_[op->buffer_var.get()].normal = true; + touched_[op->buffer->data.get()].normal = true; } StmtExprVisitor::VisitStmt_(op); } @@ -325,7 +331,8 @@ class CoProcBarrierDetector : public StorageAccessVisitor { Array wset; for (const AccessEntry& acc : wvec) { ICHECK(acc.dtype == wvec[0].dtype); - wset.push_back(acc.touched); + ICHECK_EQ(acc.touched.size(), 1) << "CoProcBarrierDetector expects flat memory"; + wset.push_back(acc.touched[0]); } Range none; Range r = arith::Union(wset).CoverRange(none); diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index e9d99cda7e13..c7cc51d27113 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -46,13 +46,30 @@ PrimExpr BufferArea(const Buffer& buffer) { } /*! - * \brief Transform multi-dimension BufferLoad/BufferStore into one-dimension Load/Store + * \brief Transform multi-dimension BufferLoad/BufferStore into device-supported dimension */ class BufferFlattener : public StmtExprMutator { public: - static Stmt Flatten(const PrimFunc& f) { return BufferFlattener().VisitStmt(f->body); } + static PrimFunc Flatten(PrimFunc func) { + Map preflattened_buffer_map = + Merge(func->buffer_map, func->preflattened_buffer_map); + + auto pass = BufferFlattener(func->buffer_map); + + auto writer = func.CopyOnWrite(); + writer->body = pass.VisitStmt(func->body); + writer->preflattened_buffer_map = preflattened_buffer_map; + writer->buffer_map = pass.updated_extern_buffer_map_; + return func; + } private: + explicit BufferFlattener(const Map& extern_buffer_map) { + for (const auto& kv : extern_buffer_map) { + updated_extern_buffer_map_.Set(kv.first, GetFlattenedBuffer(kv.second)); + } + } + Stmt VisitStmt_(const BlockRealizeNode* op) final { // We have convert blocks into opaque blocks in previous passes. ICHECK(op->iter_values.empty()) << "Non-opaque blocks are not allowed in FlattenBuffer. Please " @@ -67,8 +84,8 @@ class BufferFlattener : public StmtExprMutator { } // Step 3. 
Handle allocations in reverse order for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) { - const Buffer& buffer = new_block->alloc_buffers[i - 1]; - body = MakeAllocStmt(buffer, std::move(body)); + Buffer buffer = GetFlattenedBuffer(new_block->alloc_buffers[i - 1]); + body = Allocate(buffer->data, buffer->dtype, buffer->shape, const_true(), std::move(body)); } return body; } @@ -112,11 +129,6 @@ class BufferFlattener : public StmtExprMutator { return body; } - Stmt VisitStmt_(const BufferStoreNode* op) final { - BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); - return store->buffer.vstore(store->indices, store->value); - } - PrimExpr VisitExpr_(const VarNode* op) final { Var var = GetRef(op); auto it = unit_loop_vars_.find(var); @@ -131,16 +143,69 @@ class BufferFlattener : public StmtExprMutator { } } + Buffer GetFlattenedBuffer(Buffer buf) { + auto it = buffer_remap_.find(buf); + if (it != buffer_remap_.end()) { + return it->second; + } + + auto flattened = buf.GetFlattenedBuffer(); + + // TODO(Lunderberg): Move the handling of boolean into a + // dedicated pass. + if (flattened->dtype == DataType::Bool()) { + auto writer = flattened.CopyOnWrite(); + writer->dtype = DataType::Int(8); + } + + buffer_remap_[buf] = flattened; + return flattened; + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + + // Handle casts from the value's dtype to the dtype of the + // backing array. + // TODO(Lunderberg): Move the handling of boolean into a + // dedicated pass. + if (store->value.dtype() == DataType::Bool()) { + ICHECK_EQ(store->buffer->dtype, DataType::Int(8)) + << "Expected int8 backing array for boolean tensor"; + auto writer = store.CopyOnWrite(); + writer->value = tir::Cast(DataType::Int(8), store->value); + } + return VisitBufferAccess(std::move(store)); + } + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + bool load_returns_bool = (op->dtype == DataType::Bool()); BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); - return load->buffer.vload(load->indices, load->dtype); + load = VisitBufferAccess(load); + + // Handle casts from dtype of the backing array to value's dtype. + // TODO(Lunderberg): Move the handling of boolean into a + // dedicated pass. + if (load_returns_bool) { + ICHECK_EQ(load->buffer->dtype, DataType::Int(8)) + << "Expected int8 backing array for boolean tensor"; + return tir::Cast(DataType::Bool(), load); + } else { + return std::move(load); + } } - static Stmt MakeAllocStmt(const Buffer& buffer, Stmt body) { - String storage_scope = buffer.scope(); - PrimExpr area = BufferArea(buffer); - body = Allocate(buffer->data, buffer->dtype, {area}, const_true(), std::move(body)); - return body; + template + Node VisitBufferAccess(Node node) { + ICHECK(node->buffer.defined()); + auto flattened_indices = node->buffer->ElemOffset(node->indices); + Buffer flattened_buffer = GetFlattenedBuffer(node->buffer); + + auto writer = node.CopyOnWrite(); + writer->buffer = flattened_buffer; + writer->indices = flattened_indices; + return node; } static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var, String thread_tag, @@ -176,14 +241,18 @@ class BufferFlattener : public StmtExprMutator { /*! \brief Record the loop_var and loop start value of unit loops, whose extent is one. */ std::unordered_map unit_loop_vars_; + + /*! \brief Map of buffers being remapped.
*/ + std::unordered_map buffer_remap_; + + /*! \brief The updated external buffer map. */ + Map updated_extern_buffer_map_; }; PrimFunc FlattenBuffer(PrimFunc f) { // Only apply this pass to TIR that is not from TE schedules if (!IsFromLegacyTESchedule(f)) { - PrimFuncNode* fptr = f.CopyOnWrite(); - fptr->body = BufferFlattener::Flatten(f); - return f; + return BufferFlattener::Flatten(f); } else { return f; } diff --git a/src/tir/transforms/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc index 9e74b8cd1fdb..81842ff808ff 100644 --- a/src/tir/transforms/inject_copy_intrin.cc +++ b/src/tir/transforms/inject_copy_intrin.cc @@ -69,9 +69,9 @@ class CopyIntrinInjector : public StmtMutator { loops.push_back(op); body = op->body; } - const StoreNode* store = body.as(); + auto store = body.as(); if (store == nullptr) { - *error_info = "the 'StoreNode' of body is a nullptr."; + *error_info = "the body is not a 'BufferStoreNode'"; return false; } // Expr sel_cond, sel_true_value, sel_false_value; @@ -81,17 +81,17 @@ class CopyIntrinInjector : public StmtMutator { select(sel_cond, sel_true_value, sel_false_value).Match(store->value); const CastNode* cast = store->value.as(); - const LoadNode* load = store->value.as(); + auto load = store->value.as(); if (0 == loops.size()) { ICHECK(!has_cond); } // for now only support true condition matching if (has_cond) { - load = sel_true_value.Eval().as(); + load = sel_true_value.Eval().as(); } // cast can be part of the pattern if (cast != nullptr) { - load = cast->value.as(); + load = cast->value.as(); } if (load == nullptr) { *error_info = "the 'LoadNode' of body is a nullptr."; @@ -102,8 +102,17 @@ class CopyIntrinInjector : public StmtMutator { for (const ForNode* op : loops) { loop_vars.push_back(op->loop_var); } - Array store_strides = arith::DetectLinearEquation(store->index, loop_vars); - Array load_strides = arith::DetectLinearEquation(load->index, loop_vars); + // TODO(Lunderberg): Move this pass to be before + // StorageFlatten/FlattenBuffer. That will simplify the + // implementation, since the pre-flattened indices/strides can be + // used directly. + ICHECK((store->indices.size() == 1) && (load->indices.size() == 1)) + << "InjectCopyIntrin expects flat 1-d buffers. 
" + << "Has StorageFlatten (TE-based schedules) or " + << "FlattenBuffer (TIR-based schedules) been run?"; + + Array store_strides = arith::DetectLinearEquation(store->indices[0], loop_vars); + Array load_strides = arith::DetectLinearEquation(load->indices[0], loop_vars); if (load_strides.size() == 0 || store_strides.size() == 0) return false; Array dst_shape; const size_t loop_var_size = loop_vars.size(); @@ -160,10 +169,21 @@ class CopyIntrinInjector : public StmtMutator { src_strides.push_back(make_const(DataType::Int(32), 1)); dst_strides.push_back(make_const(DataType::Int(32), 1)); } - Buffer dst = Buffer(store->buffer_var, store->value.dtype(), dst_shape, dst_strides, - store_strides[loop_var_size], store->buffer_var->name_hint, 0, 0, kDefault); - Buffer src = Buffer(load->buffer_var, load->dtype, src_shape, src_strides, src_elem_offset, - load->buffer_var->name_hint, 0, 0, kDefault); + Buffer dst = store->buffer; + { + auto writer = dst.CopyOnWrite(); + writer->shape = dst_shape; + writer->strides = dst_strides; + writer->elem_offset = store_strides[loop_var_size]; + } + + Buffer src = load->buffer; + { + auto writer = src.CopyOnWrite(); + writer->shape = src_shape; + writer->strides = src_strides; + writer->elem_offset = src_elem_offset; + } *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value); if (!out->defined()) { *error_info = "flower function did not return correct stmt"; diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index 0b45bde28dfe..03f2ccd40dd1 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -107,15 +107,15 @@ class DoubleBufferInjector : public StmtExprMutator { auto it = dbuffer_info_.find(buf); if (it != dbuffer_info_.end()) { it->second.scope = GetPtrStorageScope(op->buffer_var); - it->second.stride = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, - make_const(DataType::Int(32), 1), op->extents) * - op->dtype.lanes(); + + ICHECK_EQ(op->extents.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " + << "Has StorageFlatten (TE-based schedules) or " + << "FlattenBuffer (TIR-based schedules) been run?"; + it->second.stride = op->extents[0]; Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - Array new_extents{make_const(op->extents[0].dtype(), 2)}; - for (PrimExpr e : op->extents) { - new_extents.push_back(e); - } + + Array new_extents = {op->extents[0] * make_const(op->extents[0].dtype(), 2)}; ICHECK(it->second.loop != nullptr); auto& alloc_nest = loop_allocs_[it->second.loop]; alloc_nest.emplace_back( @@ -170,34 +170,77 @@ class DoubleBufferInjector : public StmtExprMutator { return stmt; } + PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); + } + Stmt VisitStmt_(const StoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - auto it = dbuffer_info_.find(op->buffer_var.get()); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. 
Please use BufferStoreNode instead."; + return Stmt(); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + + auto it = dbuffer_info_.find(node->buffer->data.get()); if (it != dbuffer_info_.end()) { const StorageEntry& e = it->second; ICHECK(in_double_buffer_scope_); - ICHECK(e.stride.defined()); - return Store(op->buffer_var, op->value, e.switch_write_var * e.stride + op->index, - op->predicate); - } else { - return stmt; + ICHECK(e.switch_write_var.defined()); + + ICHECK_EQ(node->indices.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " + << "Has StorageFlatten (TE-based schedules) or " + << "FlattenBuffer (TIR-based schedules) been run?"; + + auto writer = node.CopyOnWrite(); + writer->buffer = GetRemappedBuffer(node->buffer, e.stride); + writer->indices = {e.switch_write_var * e.stride + node->indices[0]}; } + + return std::move(node); } - PrimExpr VisitExpr_(const LoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - auto it = dbuffer_info_.find(op->buffer_var.get()); + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + + auto it = dbuffer_info_.find(node->buffer->data.get()); if (it != dbuffer_info_.end()) { const StorageEntry& e = it->second; - ICHECK(e.stride.defined()); ICHECK(e.switch_read_var.defined()); - return Load(op->dtype, op->buffer_var, e.switch_read_var * e.stride + op->index, - op->predicate); - } else { - return expr; + + ICHECK_EQ(node->indices.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " + << "Has StorageFlatten (TE-based schedules) or " + << "FlattenBuffer (TIR-based schedules) been run?"; + + auto writer = node.CopyOnWrite(); + writer->buffer = GetRemappedBuffer(node->buffer, e.stride); + writer->indices = {e.switch_read_var * e.stride + node->indices[0]}; } + + return std::move(node); + } + + Buffer GetRemappedBuffer(Buffer buf, PrimExpr stride) { + auto key = buf.get(); + auto it = buf_remap_.find(key); + if (it != buf_remap_.end()) { + return it->second; + } + + ICHECK(stride.defined()); + // TODO(Lunderberg): Move this pass to before + // StorageFlatten/FlattenBuffer. That would simplify the + // implementation to inserting a new dimension for the buffer, + // rather than adjusting the other indices. + ICHECK_EQ(buf->shape.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " + << "Has StorageFlatten (TE-based schedules) or " + << "FlattenBuffer (TIR-based schedules) been run?"; + auto writer = buf.CopyOnWrite(); + writer->shape = {buf->shape[0] * stride}; + + buf_remap_[key] = buf; + return buf; } PrimExpr VisitExpr_(const VarNode* op) final { @@ -261,6 +304,8 @@ class DoubleBufferInjector : public StmtExprMutator { std::unordered_map > loop_pre_; // The allocation size of the buffer std::unordered_map dbuffer_info_; + // The updated Buffer objects + std::unordered_map buf_remap_; }; namespace transform { diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index 4964bec0334e..f6ce88cf1707 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -50,7 +50,10 @@ class ExprTouched final : public StmtExprVisitor { StmtExprVisitor::VisitStmt(n); } void VisitExpr_(const LoadNode* op) final { - HandleUseVar(op->buffer_var.get()); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. 
Please use BufferLoadNode instead."; + } + void VisitExpr_(const BufferLoadNode* op) final { + HandleUseVar(op->buffer->data.get()); StmtExprVisitor::VisitExpr_(op); } void VisitExpr_(const VarNode* op) final { HandleUseVar(op); } @@ -101,11 +104,18 @@ class VarTouchedAnalysis : public StmtVisitor { Record(op->var.get(), tc); this->VisitStmt(op->body); } + void VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + } + + void VisitStmt_(const BufferStoreNode* op) final { ExprTouched tc(touched_var_, false); tc(op->value); - tc(op->index); - Record(op->buffer_var.get(), tc); + for (const auto& index : op->indices) { + tc(index); + } + Record(op->buffer->data.get(), tc); } void VisitStmt_(const ForNode* op) final { ExprTouched tc(touched_var_, false); @@ -204,20 +214,6 @@ class VTInjector : public StmtExprMutator { PrimExpr RewriteIndex(PrimExpr index, PrimExpr alloc_extent) const { return index + var_ * alloc_extent; } - // Load - PrimExpr VisitExpr_(const LoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - if (touched_var_.count(op->buffer_var.get())) { - visit_touched_var_ = true; - } - auto it = alloc_remap_.find(op->buffer_var.get()); - if (it != alloc_remap_.end()) { - return Load(op->dtype, op->buffer_var, RewriteIndex(op->index, it->second), op->predicate); - } else { - return expr; - } - } // Expression. PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::tvm_access_ptr())) { @@ -230,7 +226,8 @@ class VTInjector : public StmtExprMutator { PrimExpr offset = this->VisitExpr(op->args[2]); PrimExpr extent = this->VisitExpr(op->args[3]); PrimExpr stride = it->second / make_const(offset.dtype(), dtype.lanes()); - offset = stride * var_ + offset; + offset = RewriteIndex(offset, stride); + return Call(op->dtype, op->op, {op->args[0], op->args[1], offset, extent, op->args[4]}); } else if (op->op.same_as(builtin::tvm_context_id())) { return allow_share_ ? GetRef(op) : var_; @@ -242,21 +239,61 @@ class VTInjector : public StmtExprMutator { trigger_base_inject_ = !allow_share_; return StmtExprMutator::VisitStmt_(op); } + // Load + PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); + } // Store Stmt VisitStmt_(const StoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - if (touched_var_.count(op->buffer_var.get())) { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. 
Please use BufferStoreNode instead."; + return Stmt(); + } + // BufferLoad + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); + } + // BufferStore + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + trigger_base_inject_ = !allow_share_; + return VisitBufferAccess(std::move(node)); + } + + template + Node VisitBufferAccess(Node node) { + if (touched_var_.count(node->buffer->data.get())) { visit_touched_var_ = true; } - trigger_base_inject_ = !allow_share_; - auto it = alloc_remap_.find(op->buffer_var.get()); + + auto it = alloc_remap_.find(node->buffer->data.get()); if (it != alloc_remap_.end()) { - return Store(op->buffer_var, op->value, RewriteIndex(op->index, it->second), op->predicate); - } else { - return stmt; + ICHECK_EQ(node->indices.size(), 1) + << "InjectVirtualThread expects rewritten allocations to be flat memory."; + auto writer = node.CopyOnWrite(); + writer->buffer = GetRemappedBuffer(node->buffer, it->second); + writer->indices = {RewriteIndex(node->indices[0], it->second)}; + } + + return node; + } + + Buffer GetRemappedBuffer(Buffer buf, PrimExpr alloc_extent) { + auto key = buf.get(); + auto it = buf_remap_.find(key); + if (it != buf_remap_.end()) { + return it->second; } + + ICHECK_EQ(buf->shape.size(), 1) << "Expected buffers being rewritten to already be flattened."; + auto writer = buf.CopyOnWrite(); + writer->shape = {buf->shape[0] * alloc_extent}; + + buf_remap_[key] = buf; + return buf; } + // Attribute Stmt VisitStmt_(const AttrStmtNode* op) final { PrimExpr value = this->VisitExpr(op->value); @@ -354,46 +391,44 @@ class VTInjector : public StmtExprMutator { } // Allocate Stmt VisitStmt_(const AllocateNode* op) final { + Allocate node = GetRef(op); + PrimExpr condition = this->VisitExpr(op->condition); + + Array extents = op->extents; + extents.MutateByApply([this](const PrimExpr& extent) { return this->VisitExpr(extent); }); + if (visit_touched_var_ && !vt_loop_injected_) { return InjectVTLoop(GetRef(op), true); } - bool changed = false; - Array extents; - for (size_t i = 0; i < op->extents.size(); i++) { - PrimExpr new_ext = this->VisitExpr(op->extents[i]); - if (visit_touched_var_ && !vt_loop_injected_) { - return InjectVTLoop(GetRef(op), true); - } - if (!new_ext.same_as(op->extents[i])) changed = true; - extents.push_back(new_ext); - } visit_touched_var_ = false; - Stmt body; - // always rewrite if not allow sharing. + // Rewrite the buffer if its shape or any value stored in it + // depends on the virtual thread var. If `allow_share_` is false, + // then the buffer is always rewritten, even if separate virtual + // threads only read from the buffer. if (touched_var_.count(op->buffer_var.get()) || !allow_share_) { // place v on highest dimension. - PrimExpr stride = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, - make_const(DataType::Int(32), 1), op->extents) * - op->dtype.lanes(); - Array other; - other.push_back(make_const(op->extents[0].dtype(), num_threads_)); - for (PrimExpr e : extents) { - other.push_back(e); - } - extents = other; - changed = true; - // mark this buffer get touched. + + // TODO(Lunderberg): Move pass to apply before + // StorageFlatten/FlattenBuffer. Would rewrite the Buffer to + // add the injected virtual thread as the first index. 
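Concretely, with num_threads_ == 4 a flat allocation of extent 16 becomes extent 64, and each access index i is later rewritten to i + vthread*16 so that every virtual thread addresses its own private slice. A sketch of the arithmetic applied in the code below (names are placeholders for the pass's state, not identifiers from the patch):

    // Enlarge the allocation: one copy of the buffer per virtual thread.
    PrimExpr new_extent = old_extent * num_threads;
    // Rewrite each access so thread vthread_var addresses its own copy;
    // this is what RewriteIndex above computes.
    PrimExpr new_index = old_index + vthread_var * old_extent;
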
+ ICHECK_EQ(extents.size(), 1) + << "InjectVirtualThread expects rewritten allocations to be flat memory."; + PrimExpr stride = extents[0]; + extents = {stride * num_threads_}; + + // Mark the buffer var as touched. BufferLoad/BufferStore should + // access locations at `current_index + stride*vthread_var`. alloc_remap_[op->buffer_var.get()] = stride; - // Mutate the body. - body = this->VisitStmt(op->body); - } else { - // Mutate the body. - body = this->VisitStmt(op->body); } - if (!changed && body.same_as(op->body) && condition.same_as(op->condition)) { + + // Mutate the body. Depends on alloc_remap_. + auto body = this->VisitStmt(op->body); + + if (extents.same_as(op->extents) && body.same_as(op->body) && + condition.same_as(op->condition)) { return GetRef(op); } else { return Allocate(op->buffer_var, op->dtype, extents, condition, body); @@ -448,8 +483,21 @@ class VTInjector : public StmtExprMutator { const std::unordered_set& touched_var_; // Whether allow shareding. bool allow_share_; - // The allocations that get touched -> extent + /* \brief The allocations that get touched -> extent + * + * Maps from the buffer_var of an allocate node to the original + * extent of the allocation. Used when rewriting the indices of + * BufferLoad/BufferStore. + */ std::unordered_map alloc_remap_; + /*! \brief Map of buffers that are modified. + * + * Buffers allocated or written to within the virtual thread loop + * must have one copy per virtual thread. This is done by enlarging + * the allocated buffer size, then modifying the indices at which + * each virtual thread accesses the buffer. + */ + std::unordered_map buf_remap_; }; class VirtualThreadInjector : public StmtMutator { diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 4eb9cc5b1a90..700c9931bba0 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -101,45 +101,89 @@ class IRConvertSSA final : public StmtExprMutator { const Var& v = op->var; if (defined_.count(v.get())) { PrimExpr value = this->VisitExpr(op->value); - Var new_var(v->name_hint, v.dtype()); - scope_[v.get()].push_back(new_var); + ScopedRedefine redefine(this, v); PrimExpr body = this->VisitExpr(op->body); - scope_[v.get()].pop_back(); - return Let(new_var, value, body); + return Let(redefine.new_var, value, body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitExpr_(op); } } + PrimExpr VisitExpr_(const LoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - const VarNode* v = op->buffer_var.get(); - if (scope_.count(v) && !scope_[v].empty()) { - return Load(op->dtype, scope_[v].back(), op->index, op->predicate); - } else { - return expr; - } + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); } + Stmt VisitStmt_(const StoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - const VarNode* v = op->buffer_var.get(); - if (scope_.count(v) && !scope_[v].empty()) { - return Store(scope_[v].back(), op->value, op->index, op->predicate); - } else { - return stmt; + LOG(FATAL) << "Unexpected use of deprecated StoreNode. 
Please use BufferStoreNode instead."; + return Stmt(); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + auto output = VisitBufferAccess(std::move(node)); + return std::move(output); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + auto output = VisitBufferAccess(std::move(node)); + return std::move(output); + } + + template + Node VisitBufferAccess(Node node) { + Buffer new_buf = GetRemappedBuffer(node->buffer); + if (!new_buf.same_as(node->buffer)) { + auto writer = node.CopyOnWrite(); + writer->buffer = new_buf; } + + return node; } + + Buffer GetRemappedBuffer(Buffer buf) { + // Determine the buffer var that should be in the updated buffer, + // given the current scope. If no redefines are present, then the + // buffer var is unchanged. + Var new_buffer_var = buf->data; + auto var_it = scope_.find(buf->data.get()); + if (var_it != scope_.end() && !var_it->second.empty()) { + new_buffer_var = var_it->second.back(); + } + + // If no mapping is required, return the original buffer. + if (new_buffer_var.same_as(buf->data)) { + return buf; + } + + // If the current scope already has a mapping of this buffer, use + // the mapped buffer. + auto key = buf.get(); + std::vector& buffers = buf_remap_[key]; + if (buffers.size() && buffers.back()->data.same_as(new_buffer_var)) { + return buffers.back(); + } + + // Otherwise, make and return a new buffer object that uses the + // new buffer, pushing it onto the scoped stack of existing + // buffers. This will be popped when the new_buffer_var + // redefinition is popped. + Buffer new_buf(new_buffer_var, buf->dtype, buf->shape, buf->strides, buf->elem_offset, + buf->name, buf->data_alignment, buf->offset_factor, buf->buffer_type, + buf->axis_separators, buf->span); + buffers.push_back(new_buf); + return new_buf; + } + Stmt VisitStmt_(const LetStmtNode* op) final { const Var& v = op->var; if (defined_.count(v.get())) { PrimExpr value = this->VisitExpr(op->value); - Var new_var(v->name_hint, v.dtype()); - scope_[v.get()].push_back(new_var); + ScopedRedefine redefine(this, v); Stmt body = this->VisitStmt(op->body); - scope_[v.get()].pop_back(); - return LetStmt(new_var, value, body); + return LetStmt(redefine.new_var, value, body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); @@ -148,12 +192,10 @@ class IRConvertSSA final : public StmtExprMutator { Stmt VisitStmt_(const ForNode* op) final { const Var& v = op->loop_var; if (defined_.count(v.get())) { - Var new_var(v->name_hint, v.dtype()); - scope_[v.get()].push_back(new_var); + ScopedRedefine redefine(this, v); Stmt stmt = StmtExprMutator::VisitStmt_(op); - scope_[v.get()].pop_back(); op = stmt.as(); - return For(new_var, op->min, op->extent, op->kind, op->body, op->thread_binding, + return For(redefine.new_var, op->min, op->extent, op->kind, op->body, op->thread_binding, op->annotations); } else { defined_.insert(v.get()); @@ -163,12 +205,10 @@ class IRConvertSSA final : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { const Var& v = op->buffer_var; if (defined_.count(v.get())) { - Var new_var(v->name_hint, v->type_annotation); - scope_[v.get()].push_back(new_var); + ScopedRedefine redefine(this, v); Stmt stmt = StmtExprMutator::VisitStmt_(op); - scope_[v.get()].pop_back(); op = stmt.as(); - return Allocate(new_var, op->dtype, op->extents, op->condition, op->body); + return Allocate(redefine.new_var, op->dtype, 
op->extents, op->condition, op->body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); @@ -189,8 +229,34 @@ class IRConvertSSA final : public StmtExprMutator { } private: + struct ScopedRedefine { + ScopedRedefine(IRConvertSSA* parent, Var old_var) : parent(parent), old_var(old_var) { + if (old_var->type_annotation.defined()) { + new_var = Var(old_var->name_hint, old_var->type_annotation); + } else { + new_var = Var(old_var->name_hint, old_var->dtype); + } + parent->scope_[old_var.get()].push_back(new_var); + } + + ~ScopedRedefine() { + parent->scope_[old_var.get()].pop_back(); + for (auto& kv : parent->buf_remap_) { + std::vector& buffers = kv.second; + if (buffers.size() && (buffers.back()->data.get() == new_var.get())) { + buffers.pop_back(); + } + } + } + + IRConvertSSA* parent; + Var old_var; + Var new_var; + }; + std::unordered_map> scope_; std::unordered_set defined_; + std::unordered_map> buf_remap_; }; Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); } diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index d7ae362b64d4..2234cc22bcfa 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -103,9 +103,11 @@ inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index, * \param offset the offset index. */ inline PrimExpr AddressOffset(Var handle, DataType dtype, int offset) { - return Call(DataType::Handle(), builtin::address_of(), - {Load(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()), - const_true(dtype.lanes()))}); + PrimExpr offset_expr = make_const(DataType::Int(32), offset * dtype.lanes()); + Buffer dummy_buf(handle, dtype, {offset_expr + 1}, {}, 0, handle->name_hint, 0, 0, kDefault); + BufferLoad buf_load(dummy_buf, {offset_expr}); + + return Call(DataType::Handle(), builtin::address_of(), {buf_load}); } /*! @@ -119,8 +121,12 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) { offset = offset * make_const(offset.dtype(), dtype.lanes()); offset = Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes()); } - return Call(DataType::Handle(), builtin::address_of(), - {Load(dtype, handle, offset, const_true(dtype.lanes()))}); + + Buffer dummy_buf(handle, dtype.element_of(), {offset + 1}, {}, 0, handle->name_hint, 0, 0, + kDefault); + BufferLoad buf_load(dummy_buf, {offset}); + + return Call(DataType::Handle(), builtin::address_of(), {buf_load}); } /*! 
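A brief usage sketch of the rewritten AddressOffset helper above (hypothetical variable names; the synthetic Buffer exists only so that address_of can receive a BufferLoad spanning the accessed element):

    Var handle("ptr", DataType::Handle());
    // Expands to address_of(BufferLoad(dummy_buf, {4})), where dummy_buf
    // wraps `handle` as a flat float32 buffer of shape {5}, just covering
    // the accessed index 4.
    PrimExpr addr = AddressOffset(handle, DataType::Float(32), 4);
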
diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index 4df38ff543b5..df8bf69e7468 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -314,7 +314,7 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, const Optional - parameters.push_back(ct_buffer->data); + parameters.push_back(BufferLoad(ct_buffer, {0})); // next arguments: all the reduction threads for (const ForNode* reduction_loop : reduction_loops) { if (reduction_loop->thread_binding.defined()) { diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index 21f1b18d523b..3cf5ed2ecf7c 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -103,32 +103,69 @@ class CustomDatatypesLowerer : public StmtExprMutator { } } - PrimExpr VisitExpr_(const LoadNode* load) final { - bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(load->dtype.code()); - PrimExpr expr = StmtExprMutator::VisitExpr_(load); - load = expr.as(); - if (to_be_lowered) { - auto new_load_type = DataType::UInt(load->dtype.bits()); - auto buffer_var = load->buffer_var; - auto it = var_remap_.find(buffer_var); - if (it != var_remap_.end()) { - buffer_var = it->second; - } - return Load(new_load_type, buffer_var, load->index, load->predicate); - } - return expr; + PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); } Stmt VisitStmt_(const StoreNode* op) final { - Stmt ret = StmtExprMutator::VisitStmt_(op); - op = ret.as(); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + auto modified = VisitBufferAccess(node); - auto it = var_remap_.find(op->buffer_var); - if (it != var_remap_.end()) { - return Store(it->second, op->value, op->index, op->predicate); + // Not needed for BufferStoreNode, so we can't just call + // LegalizeDType() in VisitBufferAccess. 
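The reason for the extra step: BufferLoadNode caches its result dtype, so once the buffer is swapped for one with a different element type, the cached dtype must be recomputed. BufferStore carries no such cache, which is why the call cannot live in the shared VisitBufferAccess helper. A condensed sketch of the load-side fix-up (it folds the buffer swap, performed in VisitBufferAccess, together with the dtype update; `remapped_buf` is a placeholder name):

    auto writer = modified.CopyOnWrite();
    writer->buffer = remapped_buf;   // e.g. custom datatype -> uint of same width
    writer->LegalizeDType();         // re-derive the cached dtype from the new buffer
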
+ if (node.same_as(modified)) { + return std::move(node); } else { - return ret; + auto writer = modified.CopyOnWrite(); + writer->LegalizeDType(); + return std::move(modified); + } + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + template + Node VisitBufferAccess(Node node) { + Buffer new_buf = GetRemappedBuffer(node->buffer); + if (!new_buf.same_as(node->buffer)) { + auto writer = node.CopyOnWrite(); + writer->buffer = new_buf; + } + + return node; + } + + Buffer GetRemappedBuffer(Buffer buf) { + auto key = buf; + auto cache_it = buf_remap_.find(key); + if (cache_it != buf_remap_.end()) { + return cache_it->second; + } + + bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(buf->dtype.code()); + + if (to_be_lowered) { + auto new_load_type = DataType::UInt(buf->dtype.bits()); + auto writer = buf.CopyOnWrite(); + writer->dtype = new_load_type; + + auto var_it = var_remap_.find(buf->data); + if (var_it != var_remap_.end()) { + writer->data = var_it->second; + } } + + buf_remap_[key] = buf; + return buf; } Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -200,6 +237,7 @@ class CustomDatatypesLowerer : public StmtExprMutator { std::string target_; // remap buffer vars std::unordered_map var_remap_; + std::unordered_map buf_remap_; }; namespace transform { diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index 6bfbcef95fc5..5bde5cb90e2b 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -177,7 +177,7 @@ class MatchBufferLower : public StmtExprMutator { Bind(buffer->data, source_buffer->data, buffer->name + ".data"); // Step.2.2. Update element offset - // Note we create Load via vload and try to reuse index calculate. + // We use the ElemOffset method to avoid duplicating the index calculation. { Array indices; indices.reserve(source->region.size()); @@ -185,11 +185,18 @@ class MatchBufferLower : public StmtExprMutator { indices.push_back(range->min); } - Load load = Downcast(source_buffer.vload(indices, source_buffer->dtype)); - Bind(buffer->elem_offset, load->index, buffer->name + ".elem_offset"); - CHECK(analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0)) - << "The source elem_offset " << load->index << " does not satisfy the offset_factor " - << buffer->offset_factor << "."; + Array buffer_start_indices = source_buffer->ElemOffset(indices); + if (buffer_start_indices.size() == 1) { + Bind(buffer->elem_offset, buffer_start_indices[0], buffer->name + ".elem_offset"); + CHECK(analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0)) + << "The source elem_offset " << buffer_start_indices[0] + << " does not satisfy the offset_factor " << buffer->offset_factor << "."; + } else { + // Non-zero elem_offset is ill-defined for non-flat memory. + // If needed in the future, will require `Array + // elem_offsets`, with one offset for each flattened index. + Bind(buffer->elem_offset, 0); + } } // Step 2.3. 
Check and update strides diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 1c6aa161e473..7e09943d0185 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -98,36 +98,97 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { if (it != alloc_remap_.end()) { const AllocateNode* repl = it->second.as(); if (warp_allocs_.count(repl)) { - stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); new_storage_scopes_[repl->buffer_var.get()] = "local"; } else { - stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); new_storage_scopes_[repl->buffer_var.get()] = "shared"; } - return stmt; + return Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); } else { return stmt; } } + PrimExpr VisitExpr_(const LoadNode* op) final { - auto it = load_remap_.find(op->buffer_var.get()); - if (it != load_remap_.end()) { - ICHECK(is_zero(op->index)); - return it->second; - } else { - return StmtExprMutator::VisitExpr_(op); - } + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); } Stmt VisitStmt_(const StoreNode* op) final { - auto it = store_remap_.find(op->buffer_var.get()); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + { + auto it = load_remap_.find(op->buffer->data.get()); + if (it != load_remap_.end()) { + for (const auto& index : op->indices) { + ICHECK(is_zero(index)); + } + return it->second; + } + } + + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + op = load.get(); + + { + auto it = buf_remap_.find(op->buffer.get()); + if (it != buf_remap_.end()) { + return BufferLoad(it->second, op->indices, op->span); + } + } + + { + auto it = var_remap_.find(op->buffer->data.get()); + if (it != var_remap_.end()) { + Buffer remapped_buffer(it->second, op->buffer->dtype, op->buffer->shape, + op->buffer->strides, op->buffer->elem_offset, op->buffer->name, + op->buffer->data_alignment, op->buffer->offset_factor, + op->buffer->buffer_type, op->buffer->axis_separators, + op->buffer->span); + buf_remap_[op->buffer.get()] = remapped_buffer; + return BufferLoad(remapped_buffer, op->indices, op->span); + } + } + return StmtExprMutator::VisitExpr_(op); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + + auto it = store_remap_.find(store->buffer.get()); if (it != store_remap_.end()) { - ICHECK(is_zero(op->index)); - auto value = StmtExprMutator::VisitExpr(op->value); - return Store(it->second, value, 0, op->predicate); - } else { - return StmtExprMutator::VisitStmt_(op); + for (const auto& index : op->indices) { + ICHECK(is_zero(index)); + } + + auto writer = store.CopyOnWrite(); + writer->buffer = it->second; + return std::move(store); + } + + { + auto it = buf_remap_.find(store->buffer.get()); + if (it != buf_remap_.end()) { + return BufferStore(it->second, store->value, store->indices, store->span); + } } + + { + auto it = var_remap_.find(store->buffer->data.get()); + if (it != var_remap_.end()) { + Buffer remapped_buffer(it->second, store->buffer->dtype, store->buffer->shape, + store->buffer->strides, store->buffer->elem_offset, + store->buffer->name, store->buffer->data_alignment, + store->buffer->offset_factor, 
store->buffer->buffer_type, + store->buffer->axis_separators, store->buffer->span); + buf_remap_[store->buffer.get()] = remapped_buffer; + return BufferStore(remapped_buffer, store->value, store->indices, store->span); + } + } + + return std::move(store); } std::unordered_map new_storage_scopes_; @@ -164,11 +225,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } types[idx] = values[idx].dtype(); } - std::vector buffers(size); + std::vector buffers(size); for (size_t idx = 0; idx < size; ++idx) { - const VarNode* buffer = call->args[2 + size + idx].as(); - ICHECK(buffer); - buffers[idx] = buffer; + PrimExpr arg = call->args[2 + size + idx]; + // Loads from boolean buffers may have cast nodes inserted by + // earlier passes. + if (auto cast = arg.as()) { + arg = cast->value; + } + buffers[idx] = Downcast(arg)->buffer; } std::unordered_set reduce_set; @@ -246,8 +311,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } std::vector seq; - std::vector shared_bufs(size); - std::vector local_vars; + std::vector shared_buffer_vars(size); + std::vector shared_bufs(size); + std::vector local_bufs; // // This is an optimization. For small reduction sizes, it may be beneficial // for a single warp to performance the entire reduction. No trips to shared @@ -271,19 +337,23 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // This is the index to the reduction variable, one reduction // variable per warp. Local scope seems easier to reason without // relying on a pattern match pass to fix it later. - PrimExpr index(0); + Array zero_indices = {0}; for (size_t idx = 0; idx < size; ++idx) { - Type ptr_type = PointerType(PrimType(types[idx])); - shared_bufs[idx] = Var("red_buf" + std::to_string(idx), ptr_type); + Array shape = {1}; + + Buffer buffer = decl_buffer(shape, types[idx], "red_buf" + std::to_string(idx)); + Var buffer_var = buffer->data; + + shared_buffer_vars[idx] = buffer_var; + shared_bufs[idx] = buffer; + PrimExpr pred = const_true(types[idx].lanes()); - seq.emplace_back(Store(shared_bufs[idx], values[idx], index, pred)); + seq.emplace_back(BufferStore(shared_bufs[idx], values[idx], zero_indices)); - // Uses a local variable to store the shuffled data. - // Later on, this allocation will be properly attached to this statement. - Var var("t" + std::to_string(idx), ptr_type); - Stmt s = Allocate(var, types[idx], {PrimExpr(1)}, pred, Evaluate(0)); - local_vars.push_back(s); + // Uses a local variable to store the shuffled data. Later + // on, an allocation will be built for this local variable. + local_bufs.push_back(decl_buffer(shape, types[idx], "t" + std::to_string(idx))); } // The mask for this reducer, as this reducer may sit inside @@ -291,18 +361,16 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // active channels. // DataType mask_dtype = DataType::UInt(32); - Var mask_var("mask", PointerType(PrimType(mask_dtype))); + Buffer mask_buffer = decl_buffer({1}, mask_dtype, "mask"); { - PrimExpr pred = const_true(1); PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {}); if (group_extent > 1) { mask = mask & (((1 << reduce_extent) - 1) << (reduce_extent * group_index)); } - seq.emplace_back(Store(mask_var, mask, index, pred)); - // Push allocation with an empty body. Later this will be fixed - // when the entire body is ready. 
- auto stmt = Allocate(mask_var, mask_dtype, {PrimExpr(1)}, pred, Evaluate(0)); - local_vars.push_back(stmt); + seq.emplace_back(BufferStore(mask_buffer, mask, zero_indices)); + // Push the buffer description. Later this will have an + // allocation built for it. + local_bufs.push_back(mask_buffer); } // Emit reductions within a warp. @@ -314,9 +382,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // Load reduction values, no synchronization needed. Array a, b; for (size_t i = 0; i < size; ++i) { - Var var = shared_bufs[i]; - PrimExpr pred = const_true(types[i].lanes()); - PrimExpr val = Load(types[i], var, index, pred); + Buffer shared_buf = shared_bufs[i]; + BufferLoad val(shared_buf, zero_indices); + ICHECK_EQ(val->dtype, types[i]); a.push_back(val); // __shfl_*sync calls shall not appear in if_then_else expressions @@ -332,12 +400,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // The former may cause dead lock as there is a divergent // branch with a warp sync call inside. // - PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_var, val, offset); - const AllocateNode* repl = local_vars[i].as(); - Stmt s = Store(repl->buffer_var, other, index, pred); + PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_buffer, val, offset); + Buffer local_buf = local_bufs[i]; + Stmt s = BufferStore(local_buf, other, zero_indices); seq.push_back(s); - PrimExpr load = Load(types[i], repl->buffer_var, index, pred); + BufferLoad load = BufferLoad(local_buf, zero_indices); + ICHECK_EQ(load->dtype, types[i]); b.push_back(load); } @@ -347,9 +416,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // Store the reduction result to itself. std::vector stores(size); for (size_t i = 0; i < size; ++i) { - Var var = shared_bufs[i]; - PrimExpr pred = const_true(types[i].lanes()); - stores[i] = Store(var, ret[i], index, pred); + Buffer buf = shared_bufs[i]; + stores[i] = BufferStore(buf, ret[i], zero_indices); } seq.push_back(SeqStmt::Flatten(stores)); } @@ -359,34 +427,35 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // uniformly writting the same result. // for (size_t i = 0; i < size; ++i) { - Var var = shared_bufs[i]; - PrimExpr pred = const_true(types[i].lanes()); - PrimExpr val = Load(types[i], var, index, pred); + Buffer buf = shared_bufs[i]; + PrimExpr val = BufferLoad(buf, zero_indices); + ICHECK_EQ(val->dtype, types[i]); PrimExpr splat = - WarpShuffle(builtin::tvm_warp_shuffle(), mask_var, val, reduce_extent * group_index); - seq.push_back(Store(var, splat, index, pred)); + WarpShuffle(builtin::tvm_warp_shuffle(), mask_buffer, val, reduce_extent * group_index); + seq.push_back(BufferStore(buf, splat, zero_indices)); } // Update existing allocations. 
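  // (For reference, the maps filled in below are the ones declared at the end
  // of this class: load_remap_ replaces later reads of the original reduction
  // result with a BufferLoad of the new warp-level buffer, store_remap_ and
  // var_remap_ redirect subsequent BufferStore/BufferLoad nodes onto that
  // buffer, and alloc_remap_ swaps in the new one-element allocation.)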
for (size_t i = 0; i < size; ++i) { - ICHECK(!load_remap_.count(buffers[i])); + ICHECK(!load_remap_.count(buffers[i]->data.get())); PrimExpr pred = const_true(types[i].lanes()); - Var var = shared_bufs[i]; - load_remap_[buffers[i]] = Load(types[i], var, index, pred); - store_remap_[buffers[i]] = var; + Buffer buf = shared_bufs[i]; + PrimExpr val = BufferLoad(buf, zero_indices); + ICHECK_EQ(val->dtype, types[i]); + load_remap_[buffers[i]->data.get()] = val; + store_remap_[buffers[i].get()] = buf; Array extents{PrimExpr(1)}; - auto node = Allocate(var, types[i], extents, pred, Evaluate(0)); - alloc_remap_[buffers[i]] = node; + auto node = Allocate(buf->data, types[i], extents, pred, Evaluate(0)); + alloc_remap_[buffers[i]->data.get()] = node; + var_remap_[buffers[i]->data.get()] = buf->data; warp_allocs_.insert(node.get()); } } else { if (reduce_extent == 1) { // special case, no reduction is needed. - std::vector stores(size); + std::vector stores; for (size_t i = 0; i < size; ++i) { - PrimExpr pred = const_true(types[i].lanes()); - Var buffer_var = Downcast(call->args[2 + size + i]); - stores[i] = Store(buffer_var, values[i], 0, pred); + stores.push_back(BufferStore(buffers[i], values[i], {0})); } return SeqStmt::Flatten(stores); } @@ -394,35 +463,38 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // previous iteration on the same buffer. seq.emplace_back(SyncThread("shared")); for (size_t idx = 0; idx < size; ++idx) { - shared_bufs[idx] = Var("red_buf" + std::to_string(idx), PointerType(PrimType(types[idx]))); + Buffer buffer = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx)); + + shared_bufs[idx] = buffer; + shared_buffer_vars[idx] = buffer->data; + PrimExpr pred = const_true(types[idx].lanes()); - seq.emplace_back(Store(shared_bufs[idx], values[idx], - BufIndex(reduce_index, group_index, reduce_extent), pred)); + seq.emplace_back(BufferStore(shared_bufs[idx], values[idx], + {BufIndex(reduce_index, group_index, reduce_extent)})); } seq.emplace_back(SyncThread("shared")); seq.emplace_back(MakeBufAllreduce(combiner, types, shared_bufs, reduce_index, group_index, reduce_extent, group_extent, contiguous_reduce_extent)); for (size_t idx = 0; idx < size; ++idx) { - ICHECK(!load_remap_.count(buffers[idx])); + ICHECK(!load_remap_.count(buffers[idx]->data.get())); PrimExpr pred = const_true(types[idx].lanes()); - load_remap_[buffers[idx]] = - Load(types[idx], shared_bufs[idx], - BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred); - alloc_remap_[buffers[idx]] = - Allocate(shared_bufs[idx], types[idx], + BufferLoad load(shared_bufs[idx], + {BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent)}); + ICHECK_EQ(load->dtype, types[idx]); + load_remap_[buffers[idx]->data.get()] = load; + alloc_remap_[buffers[idx]->data.get()] = + Allocate(shared_bufs[idx]->data, types[idx], {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0)); - store_remap_[buffers[idx]] = shared_bufs[idx]; + var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data; + store_remap_[buffers[idx].get()] = shared_bufs[idx]; } } // Fix all local allocations as all statements are built. 
Stmt body = SeqStmt::Flatten(seq); - for (auto var : local_vars) { - const AllocateNode* repl = var.as(); - if (repl) { - body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body); - new_storage_scopes_[repl->buffer_var.get()] = "local"; - } + for (Buffer buf : local_bufs) { + body = Allocate(buf->data, buf->dtype, buf->shape, const_true(buf->dtype.lanes()), body); + new_storage_scopes_[buf->data.get()] = "local"; } return body; @@ -430,8 +502,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // make allreduce. Stmt MakeBufAllreduce(const CommReducerNode* combiner, const std::vector& types, - const Array& shared_bufs, PrimExpr reduce_index, PrimExpr group_index, - int reduce_extent, int group_extent, int contiguous_reduce_extent) { + const Array& shared_bufs, PrimExpr reduce_index, + PrimExpr group_index, int reduce_extent, int group_extent, + int contiguous_reduce_extent) { // Get next power of two int reduce_align = 1; while (reduce_extent > reduce_align) { @@ -446,10 +519,14 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { auto fload = [&](int offset) { Array a, b; for (size_t i = 0; i < size; ++i) { - b.push_back(Load(types[i], shared_bufs[i], - BufIndex(reduce_index + offset, group_index, reduce_extent), - const_true())); - a.push_back(Load(types[i], shared_bufs[i], buf_index, const_true())); + BufferLoad b_load(shared_bufs[i], + {BufIndex(reduce_index + offset, group_index, reduce_extent)}); + ICHECK_EQ(b_load->dtype, types[i]); + b.push_back(b_load); + + BufferLoad a_load(shared_bufs[i], {buf_index}); + ICHECK_EQ(a_load->dtype, types[i]); + a.push_back(a_load); } Array ret = (*combiner)(a, b); return ret; @@ -457,7 +534,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { auto fstore = [&](const Array& ret) { std::vector stores(size); for (size_t i = 0; i < size; ++i) { - stores[i] = Store(shared_bufs[i], ret[i], buf_index, const_true()); + stores[i] = BufferStore(shared_bufs[i], ret[i], {buf_index}); } return SeqStmt::Flatten(stores); }; @@ -567,10 +644,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } // Emit warp shuffle calls. 
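  // A small sketch of the Call built by this helper, given a lane value `val`
  // and offset `delta` (hypothetical names), assuming the builtin shuffle
  // convention (mask, value, delta_or_lane, width, warp_size) used throughout
  // this file; the mask now comes from a one-element buffer rather than a raw
  // pointer load:
  //
  // \code
  // Buffer mask_buffer = decl_buffer({1}, DataType::UInt(32), "mask");
  // PrimExpr mask = BufferLoad(mask_buffer, {0});
  // PrimExpr width = IntImm(DataType::Int(32), 32);  // hypothetical warp size
  // PrimExpr other = Call(val.dtype(), builtin::tvm_warp_shuffle_down(),
  //                       {mask, val, delta, width, width});
  // \endcode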
- PrimExpr WarpShuffle(const Op& op, Var mask_var, PrimExpr val, PrimExpr delta_or_lane) { - PrimExpr pred = const_true(1); - PrimExpr index(0); - PrimExpr mask = Load(DataType::UInt(32), mask_var, index, pred); + PrimExpr WarpShuffle(const Op& op, Buffer mask_buffer, PrimExpr val, PrimExpr delta_or_lane) { + Array indices = {0}; + PrimExpr mask = BufferLoad(mask_buffer, indices); PrimExpr width = IntImm(DataType::Int(32), warp_size_); Array args{mask, val, delta_or_lane, width, width}; return Call(val.dtype(), op, args); @@ -640,9 +716,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // The load remap std::unordered_map load_remap_; // The store remap - std::unordered_map store_remap_; + std::unordered_map store_remap_; // Allocate remap std::unordered_map alloc_remap_; + // BufferVar remap + std::unordered_map var_remap_; + // Buffer remap + std::unordered_map buf_remap_; // Allocate from warp reductions std::unordered_set warp_allocs_; // Internal analyzer diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index bcf763ca8a93..7f0631d00e57 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -34,16 +34,125 @@ namespace tvm { namespace tir { +class StackSizeChecker : public StmtExprVisitor { + public: + struct StackSizes { + // If a tvm_stack_make_shape call has no arguments, it is still + // valid and represents a scalar shape (). Therefore, -1 is used + // to represent "no shape arguments exist", while 0 represents + // "shape arguments exist, all of which are size 0". + int64_t shape_stack{-1}; + uint64_t array_stack{0}; + uint64_t arg_stack{0}; + }; + + static StackSizes Check(Stmt stmt) { + StackSizeChecker visitor; + visitor.VisitStmt(stmt); + return visitor.max_stack_; + } + + private: + void VisitStmt_(const ForNode* op) final { + if (op->kind == ForKind::kParallel) { + // Parallel for loops have their own stack and allocations, so + // stop the recursion here. + return; + } else { + this->VisitStmt(op->body); + } + } + void VisitExpr_(const CallNode* op) final { + if (op->op.same_as(builtin::tvm_call_packed())) { + return MakeCallPacked(op, /* use_string_lookup */ true); + } else if (op->op.same_as(builtin::tvm_call_cpacked())) { + return MakeCallPacked(op, /* use_string_lookup */ false); + } else if (op->op.same_as(builtin::tvm_call_trace_packed())) { + return MakeCallTracePacked(op); + } else if (op->op.same_as(builtin::tvm_stack_make_shape())) { + return MakeShape(op); + } else if (op->op.same_as(builtin::tvm_stack_make_array())) { + return MakeArray(op); + } else { + return StmtExprVisitor::VisitExpr_(op); + } + } + // call shape + void MakeShape(const CallNode* op) { + // if args.size() == 0, it is still valid and represents a scalar + // shape (). Therefore, -1 is used to represent "no shape + // arguments exist", while 0 represents "shape arguments exist, + // all of which are size 0". + if (current_stack_.shape_stack == -1) { + current_stack_.shape_stack = 0; + } + current_stack_.shape_stack += op->args.size(); + StmtExprVisitor::VisitExpr_(op); + } + // make array + void MakeArray(const CallNode* op) { + current_stack_.array_stack += 1; + StmtExprVisitor::VisitExpr_(op); + } + // call packed. 
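  // Accounting sketch for a hypothetical call tvm_call_packed("fadd", x, y):
  // op->args.size() is 3 (the function name occupies args[0]), so arg_stack
  // grows by 3 while the call's arguments are prepared and is restored
  // afterwards; only the high-water mark survives in max_stack_.
  //
  // \code
  // Stmt body = Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(),
  //                           {StringImm("fadd"), x, y}));
  // StackSizeChecker::StackSizes sizes = StackSizeChecker::Check(body);
  // // sizes.arg_stack == 3, sizes.shape_stack == -1, sizes.array_stack == 0
  // \endcode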
+ void MakeCallPacked(const CallNode* op, bool use_string_lookup) { + StackSizes restore_stack = current_stack_; + + size_t arg_count = op->args.size(); + + // cpacked expects a resource_handle parameter + if (!use_string_lookup) { + arg_count--; + } + + current_stack_.arg_stack += arg_count; + // Specially handle the buffer packed intrinsic + StmtExprVisitor::VisitExpr_(op); + // Record the amount of stack space needed, then reset the stack + // position to its previous location. + UpdateMaxStack(); + current_stack_ = restore_stack; + } + + void MakeCallTracePacked(const CallNode* op) { + StackSizes restore_stack = current_stack_; + + size_t args_size = op->args.size(); + ICHECK_GT(args_size, 0); + current_stack_.arg_stack += args_size; + + StmtExprVisitor::VisitExpr_(op); + // Record the amount of stack space needed, then reset the stack + // position to its previous location. + UpdateMaxStack(); + current_stack_ = restore_stack; + + // However, the arguments to this CallNode remain on top of the + // stack, so we can use more than one packed function's arguments + // with the one stack. + current_stack_.arg_stack = restore_stack.arg_stack + args_size - 1; + } + + void UpdateMaxStack() { + max_stack_.arg_stack = std::max(current_stack_.arg_stack, max_stack_.arg_stack); + max_stack_.shape_stack = std::max(current_stack_.shape_stack, max_stack_.shape_stack); + max_stack_.array_stack = std::max(current_stack_.array_stack, max_stack_.array_stack); + } + + StackSizes current_stack_; + StackSizes max_stack_; +}; + // Calculate the statistics of packed function. // These information are needed during codegen. class BuiltinLower : public StmtExprMutator { public: // Record stack frame for existing scope. struct AllocaScope { - Var stack_shape = Var("stack_shape", DataType::Handle()); + Buffer stack_shape; Var stack_array = Var("stack_array", DataType::Handle()); Var stack_value = Var("stack_value", DataType::Handle()); - Var stack_tcode = Var("stack_tcode", DataType::Handle()); + Buffer stack_tcode; int64_t max_shape_stack{-1}; uint64_t max_array_stack{0}; @@ -58,21 +167,41 @@ class BuiltinLower : public StmtExprMutator { // Allcoate stack frames, only at parallel-for or root. Stmt VisitBodyAndRealizeAlloca(Stmt stmt) { + // Initial check to identify maximum stack sizes. These are used + // to construct Buffer objects to hold the stack, which are then + // used when mutating. 
+ auto max_sizes = StackSizeChecker::Check(stmt); + alloca_scope_.emplace_back(); - stmt = this->VisitStmt(stmt); - ICHECK(!alloca_scope_.empty()); auto& scope = alloca_scope_.back(); - if (scope.max_shape_stack != -1) { - stmt = LetStmt(scope.stack_shape, StackAlloca("shape", scope.max_shape_stack), stmt); + + if (max_sizes.shape_stack != -1) { + scope.stack_shape = decl_buffer({IntImm(DataType::Int(64), max_sizes.shape_stack)}, + DataType::Int(64), "stack_shape"); + stmt = LetStmt(scope.stack_shape->data, StackAlloca("shape", max_sizes.shape_stack), stmt); } - if (scope.max_array_stack != 0) { - stmt = LetStmt(scope.stack_array, StackAlloca("array", scope.max_array_stack), stmt); + if (max_sizes.array_stack != 0) { + stmt = LetStmt(scope.stack_array, StackAlloca("array", max_sizes.array_stack), stmt); } - if (scope.max_arg_stack != 0) { - stmt = LetStmt(scope.stack_value, StackAlloca("arg_value", scope.max_arg_stack), stmt); - stmt = LetStmt(scope.stack_tcode, StackAlloca("arg_tcode", scope.max_arg_stack), stmt); + + if (max_sizes.arg_stack != 0) { + scope.stack_tcode = decl_buffer({IntImm(DataType::UInt(64), max_sizes.arg_stack)}, + DataType::Int(32), "stack_tcode"); + stmt = LetStmt(scope.stack_value, StackAlloca("arg_value", max_sizes.arg_stack), stmt); + + stmt = LetStmt(scope.stack_tcode->data, StackAlloca("arg_tcode", max_sizes.arg_stack), stmt); } + + // Copy these values from the earlier search, for use in bounds + // checks. + scope.max_shape_stack = max_sizes.shape_stack; + scope.max_array_stack = max_sizes.array_stack; + scope.max_arg_stack = max_sizes.arg_stack; + + stmt = this->VisitStmt(stmt); + + ICHECK(!alloca_scope_.empty()); alloca_scope_.pop_back(); return stmt; @@ -244,10 +373,10 @@ class BuiltinLower : public StmtExprMutator { op = expr.as(); // no need to perform any store for a scalar shape for (size_t i = 0; i < op->args.size(); ++i) { - prep_seq.emplace_back(Store(scope.stack_shape, cast(DataType::Int(64), op->args[i]), - ConstInt32(stack_begin + i), const_true(1))); + prep_seq.emplace_back(BufferStore(scope.stack_shape, cast(DataType::Int(64), op->args[i]), + {ConstInt32(stack_begin + i)})); } - return AddressOffset(scope.stack_shape, DataType::Int(64), stack_begin); + return AddressOffset(scope.stack_shape->data, DataType::Int(64), stack_begin); } // make array PrimExpr MakeArray(const CallNode* op) { @@ -328,17 +457,16 @@ class BuiltinLower : public StmtExprMutator { arg_tcode = kTVMStr; } if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle; - prep_seq.emplace_back( - Store(scope.stack_tcode, ConstInt32(arg_tcode), stack_index, const_true(1))); + prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index})); } - // UPDATE stack value - scope.max_arg_stack = std::max(scope.run_arg_stack, scope.max_arg_stack); - scope.max_shape_stack = std::max(scope.run_shape_stack, scope.max_shape_stack); - scope.max_array_stack = std::max(scope.run_array_stack, scope.max_array_stack); + // Verify stack size matches earlier value. 
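  // (ICHECK_LE rather than ICHECK_EQ: StackSizeChecker records the high-water
  // mark across the whole body, so an individual call site may use strictly
  // less than the precomputed maximum.)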
+ ICHECK_LE(scope.run_arg_stack, scope.max_arg_stack); + ICHECK_LE(scope.run_shape_stack, scope.max_shape_stack); + ICHECK_LE(scope.run_array_stack, scope.max_array_stack); scope.run_shape_stack = restore_shape_stack; scope.run_array_stack = restore_array_stack; scope.run_arg_stack = arg_stack_begin; - Array packed_args = {op->args[0], scope.stack_value, scope.stack_tcode, + Array packed_args = {op->args[0], scope.stack_value, scope.stack_tcode->data, ConstInt32(arg_stack_begin), ConstInt32(arg_stack_begin + op->args.size() - 1)}; @@ -379,19 +507,18 @@ class BuiltinLower : public StmtExprMutator { builtin::kTVMValueContent, arg)); int arg_tcode = api_type.code(); ICHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers"; - prep_seq.emplace_back( - Store(scope.stack_tcode, ConstInt32(arg_tcode), stack_index, const_true(1))); + prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index})); } - // UPDATE stack value - scope.max_arg_stack = std::max(scope.run_arg_stack, scope.max_arg_stack); - scope.max_shape_stack = std::max(scope.run_shape_stack, scope.max_shape_stack); - scope.max_array_stack = std::max(scope.run_array_stack, scope.max_array_stack); + // Verify stack size matches earlier value. + ICHECK_LE(scope.run_arg_stack, scope.max_arg_stack); + ICHECK_LE(scope.run_shape_stack, scope.max_shape_stack); + ICHECK_LE(scope.run_array_stack, scope.max_array_stack); scope.run_shape_stack = restore_shape_stack; scope.run_array_stack = restore_array_stack; // Update the top of the stack, so we can use more than one // packed function's arguments with the one stack. scope.run_arg_stack = arg_stack_begin + args_size - 1; - Array packed_args = {op->args[0], scope.stack_value, scope.stack_tcode, + Array packed_args = {op->args[0], scope.stack_value, scope.stack_tcode->data, ConstInt32(arg_stack_begin), ConstInt32(arg_stack_begin + op->args.size() - 1), // Pass traced value. diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index f316ae9606d0..40971114d416 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -114,19 +114,31 @@ class WarpStoreCoeffFinder : private StmtVisitor { private: /// Visitor implementation void VisitStmt_(const StoreNode* op) final { - if (op->buffer_var.get() == buffer_) { - if (op->value.dtype().lanes() == 1) { - UpdatePattern(op->index); - } else { - arith::PVar base; - ICHECK(arith::ramp(base, 1, op->value.dtype().lanes()).Match(op->index)) - << "LowerWarpMemory failed due to store index=" << op->index - << ", can only handle continuous store"; - UpdatePattern(base.Eval()); - } - } else { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + } + + void VisitStmt_(const BufferStoreNode* op) final { + if (op->buffer->data.get() != buffer_) { StmtVisitor::VisitStmt_(op); + return; + } + + ICHECK_EQ(op->indices.size(), 1) << "Expected flat memory to use as warp memory. 
" + << "Has StorageFlatten (TE-based schedule) or " + << "FlattenBuffer (TIR-based schedules) been run?"; + + PrimExpr index = op->indices[0]; + if (op->value.dtype().lanes() != 1) { + arith::PVar base; + ICHECK(arith::ramp(base, 1, op->value.dtype().lanes()).Match(index)) + << "LowerWarpMemory failed due to store index=" << index + << ", can only handle continuous store"; + UpdatePattern(base.Eval()); + + index = base.Eval(); } + + UpdatePattern(index); } void UpdatePattern(const PrimExpr& index) { @@ -239,35 +251,62 @@ class WarpAccessRewriter : protected StmtExprMutator { } Stmt VisitStmt_(const StoreNode* op) override { - if (op->buffer_var.get() == buffer_) { - PrimExpr local_index, group; - std::tie(local_index, group) = SplitIndexByGroup(op->index); - PrimExpr new_value = VisitExpr(op->value); - return Store(op->buffer_var, new_value, local_index, op->predicate); - } else { - return StmtExprMutator::VisitStmt_(op); - } + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); } PrimExpr VisitExpr_(const LoadNode* op) override { - if (op->buffer_var.get() == buffer_) { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); + } + + Stmt VisitStmt_(const BufferStoreNode* op) override { + auto store = Downcast(StmtExprMutator::VisitStmt_(op)); + + if (store->buffer->data.get() == buffer_) { + ICHECK_EQ(store->indices.size(), 1) << "Expected flat memory to use as warp memory. " + << "Has StorageFlatten (TE-based schedule) or " + << "FlattenBuffer (TIR-based schedules) been run?"; + PrimExpr local_index, group; - std::tie(local_index, group) = SplitIndexByGroup(op->index); - // invariance: local index must do not contain warp id - ICHECK(!UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); })) - << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index - << " local_index=" << local_index; - PrimExpr load_value = Load(op->dtype, op->buffer_var, local_index, op->predicate); - if (analyzer_->CanProveEqual(group, warp_index_)) { - return load_value; - } - PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}); - return Call(load_value.dtype(), builtin::tvm_warp_shuffle(), - {mask, load_value, group, width_, warp_size_}); - } else { - return StmtExprMutator::VisitExpr_(op); + std::tie(local_index, group) = SplitIndexByGroup(store->indices[0]); + + auto writer = store.CopyOnWrite(); + writer->indices = {local_index}; + } + + return std::move(store); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) override { + auto load = Downcast(StmtExprMutator::VisitExpr_(op)); + + if (load->buffer->data.get() != buffer_) { + return std::move(load); + } + + ICHECK_EQ(op->indices.size(), 1) << "Expected flat memory to use as warp memory. 
" + << "Has StorageFlatten (TE-based schedule) or " + << "FlattenBuffer (TIR-based schedules) been run?"; + + PrimExpr local_index, group; + std::tie(local_index, group) = SplitIndexByGroup(op->indices[0]); + // invariance: local index must do not contain warp id + ICHECK(!UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); })) + << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->indices[0] + << " local_index=" << local_index; + + auto writer = load.CopyOnWrite(); + writer->indices = {local_index}; + + if (analyzer_->CanProveEqual(group, warp_index_)) { + return std::move(load); } + + PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}); + return Call(load.dtype(), builtin::tvm_warp_shuffle(), {mask, load, group, width_, warp_size_}); } + // Split the index to the two component // // local index is the index in the local diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index d7e1beff03d3..a31349fe1c07 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -61,34 +61,63 @@ class ReturnRewriter : public StmtMutator { if (call->op.same_as(builtin::ret())) { ICHECK_EQ(in_parallel_, 0) << "tir.ret cannot be used in parallel scope."; ICHECK_EQ(call->args.size(), 1) << "tir.ret expect a single argument."; - ret = WriteToOut(call->args[0], ret_var_, ret_tcode_); + ret = WriteToOut(call->args[0]); } } return ret; } private: - std::pair ConvertForFFI(PrimExpr val) { + struct ConvertedInfo { + int tcode{-1}; + PrimExpr expr; + Buffer dummy_val_buffer; + Buffer dummy_tcode_buffer; + }; + + ConvertedInfo ConvertForFFI(PrimExpr val) { + ConvertedInfo info; + // convert val's data type to FFI data type, return type code DataType dtype = val.dtype(); if (dtype.is_int() || dtype.is_uint()) { - return {kTVMArgInt, Cast(DataType::Int(64), val)}; + info.tcode = kTVMArgInt; + info.expr = Cast(DataType::Int(64), val); } else if (dtype.is_float()) { - return {kTVMArgFloat, Cast(DataType::Float(64), val)}; + info.tcode = kTVMArgFloat; + info.expr = Cast(DataType::Float(64), val); } else if (dtype.is_void()) { - return {kTVMNullptr, val}; + info.tcode = kTVMNullptr; + info.expr = val; } else { LOG(FATAL) << "data type " << dtype << " not supported yet"; } - return {kTVMNullptr, val}; + + // If multiple return locations have the same data type, use the + // same dummy buffer declaration. + auto it = dummy_val_buffer_map_.find(info.tcode); + if (it != dummy_val_buffer_map_.end()) { + info.dummy_val_buffer = it->second; + } else { + info.dummy_val_buffer = Buffer(ret_var_, info.expr.dtype(), {1}, {1}, ConstInt32(0), + ret_var_->name_hint, 0, 0, kDefault); + dummy_val_buffer_map_[info.tcode] = info.dummy_val_buffer; + } + + // The tcode is always a 32-bit int, so we don't need to have a separate map. 
+ if (!dummy_tcode_buffer_.defined()) { + dummy_tcode_buffer_ = Buffer(ret_tcode_, DataType::Int(32), {1}, {1}, ConstInt32(0), + ret_tcode_->name_hint, 0, 0, kDefault); + } + info.dummy_tcode_buffer = dummy_tcode_buffer_; + + return info; } - Stmt WriteToOut(PrimExpr val, Var ret_var, Var ret_tcode) { - auto p = ConvertForFFI(val); - int tcode = p.first; - val = p.second; - Stmt store_val = Store(ret_var_, val, 0, const_true()); - Stmt store_tcode = Store(ret_tcode_, tcode, 0, const_true()); + Stmt WriteToOut(PrimExpr val) { + auto info = ConvertForFFI(val); + Stmt store_val = BufferStore(info.dummy_val_buffer, info.expr, {0}); + Stmt store_tcode = BufferStore(info.dummy_tcode_buffer, info.tcode, {0}); Stmt ret_zero = Evaluate(tvm::ret(0)); return SeqStmt({store_val, store_tcode, ret_zero}); } @@ -96,6 +125,9 @@ class ReturnRewriter : public StmtMutator { Var ret_var_; Var ret_tcode_; int in_parallel_{0}; + + std::unordered_map dummy_val_buffer_map_; + Buffer dummy_tcode_buffer_; }; Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) { @@ -131,10 +163,11 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { // Data field definitions // The packed fields Var v_packed_args("args", DataType::Handle()); - Var v_packed_arg_type_ids("arg_type_ids", DataType::Handle()); + Buffer buf_packed_arg_type_ids = decl_buffer({IntImm(DataType::Int(32), func_ptr->params.size())}, + DataType::Int(32), "arg_type_ids"); Var v_num_packed_args("num_args", DataType::Int(32)); - Var v_out_ret_value("out_ret_value", DataType::Handle()); - Var v_out_ret_tcode("out_ret_tcode", DataType::Handle()); + Var v_out_ret_value("out_ret_value", PointerType(PrimType(DataType::Void()))); + Var v_out_ret_tcode("out_ret_tcode", PointerType(PrimType(DataType::Int(32)))); Var v_resource_handle("resource_handle", DataType::Handle()); // The arguments of the function. Array args; @@ -166,7 +199,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { // add signature for packed arguments. 
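  // The packed parameter list assembled here mirrors TVM's
  // TVMBackendPackedCFunc convention from
  // include/tvm/runtime/c_backend_api.h, roughly:
  //
  // \code
  // int func(TVMValue* args, int* arg_type_ids, int num_args,
  //          TVMValue* out_ret_value, int* out_ret_tcode,
  //          void* resource_handle);
  // \endcode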
if (pack_args) { args.push_back(v_packed_args); - args.push_back(v_packed_arg_type_ids); + args.push_back(buf_packed_arg_type_ids->data); args.push_back(v_num_packed_args); } @@ -185,21 +218,21 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { continue; } - auto it = func_ptr->buffer_map.find(param); - if (it != func_ptr->buffer_map.end()) { - buffer_def.emplace_back(v_arg, (*it).second); + if (func_ptr->preflattened_buffer_map.count(param)) { + buffer_def.emplace_back(v_arg, func_ptr->preflattened_buffer_map[param]); + } else if (func_ptr->buffer_map.count(param)) { + buffer_def.emplace_back(v_arg, func_ptr->buffer_map[param]); } else { var_def.emplace_back(v_arg, param); } + if (i < num_packed_args) { // Value loads seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop)); // type code checks Var tcode(v_arg->name_hint + ".code", DataType::Int(32)); - seq_init.emplace_back(LetStmt(tcode, - Load(DataType::Int(32), v_packed_arg_type_ids, - IntImm(DataType::Int(32), i), const_true(1)), - nop)); + seq_init.emplace_back( + LetStmt(tcode, BufferLoad(buf_packed_arg_type_ids, {IntImm(DataType::Int(32), i)}), nop)); DataType t = v_arg.dtype(); if (t.is_handle()) { std::ostringstream msg; diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index b10e4439b99d..e61af842b507 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -102,12 +102,17 @@ class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor { alloc_info_[buf].level = level; StmtExprVisitor::VisitStmt_(op); } + void VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + } + + void VisitStmt_(const BufferStoreNode* op) final { scope_.push_back(StmtEntry()); // visit subexpr StmtExprVisitor::VisitStmt_(op); // Add write access. - const VarNode* buf = op->buffer_var.get(); + const VarNode* buf = op->buffer->data.get(); auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { ICHECK_LT(it->second.level, scope_.size()); @@ -122,6 +127,7 @@ class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor { linear_seq_.push_back(e); } } + void VisitStmt_(const EvaluateNode* op) final { scope_.push_back(StmtEntry()); // visit subexpr @@ -133,10 +139,15 @@ class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor { linear_seq_.push_back(e); } } + void VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + } + + void VisitExpr_(const BufferLoadNode* op) final { // Add write access. 
StmtExprVisitor::VisitExpr_(op); - const VarNode* buf = op->buffer_var.get(); + const VarNode* buf = op->buffer->data.get(); auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store."; @@ -145,10 +156,13 @@ class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor { } } } + void VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::address_of())) { - const LoadNode* l = op->args[0].as(); - this->VisitExpr(l->index); + const BufferLoadNode* load = op->args[0].as(); + for (const auto& index : load->indices) { + this->VisitExpr(index); + } } else { StmtExprVisitor::VisitExpr_(op); } @@ -294,22 +308,61 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { } PrimExpr VisitExpr_(const LoadNode* op) final { - if (IsDynamicSharedMemory(op->buffer_var)) { - PrimExpr offset = GetBufferOffset(op->buffer_var, op->dtype); - PrimExpr index = StmtExprMutator::VisitExpr(op->index); - return Load(op->dtype, merged_buf_var_, offset + index, op->predicate, op->span); - } - return StmtExprMutator::VisitExpr_(op); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); } Stmt VisitStmt_(const StoreNode* op) final { - if (IsDynamicSharedMemory(op->buffer_var)) { - PrimExpr offset = GetBufferOffset(op->buffer_var, op->value->dtype); - PrimExpr index = StmtExprMutator::VisitExpr(op->index); - PrimExpr value = StmtExprMutator::VisitExpr(op->value); - return Store(merged_buf_var_, value, offset + index, op->predicate, op->span); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + template + Node VisitBufferAccess(Node node) { + if (IsDynamicSharedMemory(node->buffer->data)) { + ICHECK_EQ(node->indices.size(), 1) + << "MergeDynamicSharedMemoryAllocations expects flat memory buffers, " + << "and is to be run after " + << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)"; + Array indices = {node->indices[0] + + this->GetBufferOffset(node->buffer->data, node->buffer->dtype)}; + + auto writer = node.CopyOnWrite(); + writer->buffer = GetUpdatedBuffer(node->buffer); + writer->indices = indices; } - return StmtExprMutator::VisitStmt_(op); + + return node; + } + + Buffer GetUpdatedBuffer(Buffer buffer) { + auto key = buffer.get(); + auto it = buffer_remap_.find(key); + if (it != buffer_remap_.end()) { + return it->second; + } + + if (IsDynamicSharedMemory(buffer->data)) { + ICHECK_EQ(buffer->shape.size(), 1) + << "MergeDynamicSharedMemoryAllocations expects flat memory buffers, " + << "and is to be run after " + << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)"; + auto writer = buffer.CopyOnWrite(); + writer->data = merged_buf_var_; + } + + buffer_remap_[key] = buffer; + return buffer; } PrimExpr VisitExpr_(const CallNode* op) final { @@ -542,6 +595,8 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { PrimExpr merged_alloc_size_{0}; // The mapping from the original buffer var to its offset in the merged buffer std::unordered_map buffer_byte_offsets_; + // The mapping from the original 
buffer objects to their location in the merged buffer. + std::unordered_map buffer_remap_; // The flag indicating whether the merged buffer has been allocated bool allocated_{false}; // Locations of free ops. diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index dd5f54e52455..d5d145653fa3 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -205,12 +205,52 @@ class DataTypeRewriter : public StmtExprMutator { } Stmt VisitStmt_(const StoreNode* op) final { - PrimExpr value = this->VisitExpr(op->value); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = GetRef(op); + + auto value = this->VisitExpr(op->value); + auto indices = VisitIndices(op->indices); + + if (!value.same_as(op->value) || !indices.same_as(op->indices)) { + auto writer = store.CopyOnWrite(); + writer->value = value; + writer->indices = indices; + } + + return std::move(store); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = GetRef(op); + + auto indices = VisitIndices(op->indices); + + if (!indices.same_as(op->indices)) { + auto writer = load.CopyOnWrite(); + writer->indices = indices; + } + + return std::move(load); + } + + Array VisitIndices(Array indices) { is_index_ = true; - PrimExpr index = this->VisitExpr(op->index); + + auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); }; + indices.MutateByApply(fmutate); + is_index_ = false; - Stmt s = Store(op->buffer_var, op->value, index, op->predicate); - return StmtExprMutator::VisitStmt_(s.as()); + + return indices; } Stmt VisitStmt_(const ForNode* op) final { @@ -280,14 +320,6 @@ class DataTypeRewriter : public StmtExprMutator { return StmtExprMutator::VisitExpr_(op); } - PrimExpr VisitExpr_(const LoadNode* op) final { - is_index_ = true; - PrimExpr index = this->VisitExpr(op->index); - is_index_ = false; - PrimExpr e = Load(op->dtype, op->buffer_var, index, op->predicate); - return StmtExprMutator::VisitExpr_(e.as()); - } - PrimExpr VisitExpr_(const IntImmNode* op) final { if (is_index_) { if (visitor_.vmap.find(op) != visitor_.vmap.end()) { diff --git a/src/tir/transforms/rewrite_unsafe_select.cc b/src/tir/transforms/rewrite_unsafe_select.cc index f1286d773c2d..8a37f9958073 100644 --- a/src/tir/transforms/rewrite_unsafe_select.cc +++ b/src/tir/transforms/rewrite_unsafe_select.cc @@ -42,8 +42,13 @@ class UnsafeExprDetector : public ExprFunctor { if (op->op.same_as(builtin::if_then_else())) { return VisitExpr(op->args[0]); } else if (op->op.same_as(builtin::address_of())) { - const LoadNode* l = op->args[0].as(); - return this->VisitExpr(l->index); + const BufferLoadNode* load = op->args[0].as(); + for (const auto& index : load->indices) { + if (VisitExpr(index)) { + return true; + } + } + return false; } else if (auto* ptr_op = op->op.as()) { auto effect_kind = op_call_effect_[GetRef(ptr_op)]; if (effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation) { @@ -58,10 +63,14 @@ class UnsafeExprDetector : public ExprFunctor { return true; } } - bool VisitExpr_(const LoadNode* op) { + bool VisitExpr_(const BufferLoadNode* op) { // Load is considered unsafe. 
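    // Rewriting a Select into unconditional arithmetic evaluates both
    // branches eagerly, so a guarded access such as
    // Select(i < n, A[i], 0.0f) could read out of bounds; any BufferLoad
    // therefore forces the conservative path.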
return true; } + bool VisitExpr_(const LoadNode* op) { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return true; + } bool VisitExpr_(const AddNode* op) final { return BinaryOp(op); } bool VisitExpr_(const SubNode* op) final { return BinaryOp(op); } bool VisitExpr_(const MulNode* op) final { return BinaryOp(op); } diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index df8816c8f693..7d4fac8d7b2d 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -82,17 +82,37 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { } } - // eliminate useless stores Stmt VisitStmt_(const StoreNode* op) final { - Stmt stmt = Parent::VisitStmt_(op); - op = stmt.as(); - if (const LoadNode* load = op->value.as()) { - if (load->buffer_var.same_as(op->buffer_var) && - tir::ExprDeepEqual()(load->index, op->index)) { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + // eliminate useless stores + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = Downcast(Parent::VisitStmt_(op)); + if (const BufferLoadNode* load = store->value.as()) { + if (load->buffer->data.same_as(store->buffer->data) && + ArrayDeepEqual(load->indices, store->indices) && + tir::ExprDeepEqual()(load->buffer->elem_offset, store->buffer->elem_offset) && + ArrayDeepEqual(load->buffer->shape, store->buffer->shape) && + ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) { return Evaluate(0); } } - return GetRef(op); + return std::move(store); + } + + private: + bool ArrayDeepEqual(const Array& lhs, const Array& rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (size_t i = 0; i < lhs.size(); i++) { + if (!tir::ExprDeepEqual()(lhs[i], rhs[i])) { + return false; + } + } + return true; } }; diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index e54aceb16a77..1b8c150079c7 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -112,7 +112,12 @@ class VarUseDefAnalysis : public StmtExprMutator { } Stmt VisitStmt_(const StoreNode* op) final { - this->HandleUse(op->buffer_var); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + VisitBuffer(op->buffer); return StmtExprMutator::VisitStmt_(op); } @@ -160,10 +165,27 @@ class VarUseDefAnalysis : public StmtExprMutator { } PrimExpr VisitExpr_(const LoadNode* op) final { - this->HandleUse(op->buffer_var); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. 
Please use BufferLoadNode instead."; + return PrimExpr(); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + VisitBuffer(op->buffer); return StmtExprMutator::VisitExpr_(op); } + void VisitBuffer(Buffer buffer) { + this->HandleUse(buffer->data); + auto visit_arr = [&](Array arr) { + for (const auto& element : arr) { + this->VisitExpr(element); + } + }; + + visit_arr(buffer->shape); + visit_arr(buffer->strides); + } + void HandleDef(const VarNode* v) { ICHECK(!def_count_.count(v)) << "variable " << v->name_hint << " has already been defined, the Stmt is not SSA"; diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 0567c8613fcd..4f19f708880c 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -34,15 +34,25 @@ namespace tvm { namespace tir { void StorageAccessVisitor::VisitExpr_(const LoadNode* op) { - const VarNode* buf = op->buffer_var.as(); - StorageScope scope = GetScope(op->buffer_var); - if (Enabled(buf, scope)) { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; +} + +void StorageAccessVisitor::VisitStmt_(const StoreNode* op) { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; +} + +void StorageAccessVisitor::VisitExpr_(const BufferLoadNode* op) { + Var buf = op->buffer->data; + StorageScope scope = GetScope(buf); + if (Enabled(buf.get(), scope)) { ICHECK(allow_append_) << op << " " << scope.to_string(); AccessEntry e; e.threads = env_threads(); - e.buffer = op->buffer_var; + e.buffer = buf; e.dtype = op->dtype.element_of(); - e.touched = arith::IntSet::Vector(op->index); + for (const auto& index : op->indices) { + e.touched.push_back(arith::IntSet::Vector(index)); + } e.type = kRead; e.scope = scope; curr_stmt_.access.emplace_back(std::move(e)); @@ -51,18 +61,21 @@ void StorageAccessVisitor::VisitExpr_(const LoadNode* op) { StmtExprVisitor::VisitExpr_(op); } -void StorageAccessVisitor::VisitStmt_(const StoreNode* op) { +void StorageAccessVisitor::VisitStmt_(const BufferStoreNode* op) { allow_append_ = true; ICHECK_EQ(curr_stmt_.access.size(), 0U); curr_stmt_.stmt = op; - const VarNode* buf = op->buffer_var.as(); - StorageScope scope = GetScope(op->buffer_var); - if (Enabled(buf, scope)) { + + Var buf = op->buffer->data; + StorageScope scope = GetScope(buf); + if (Enabled(buf.get(), scope)) { AccessEntry e; e.threads = env_threads(); - e.buffer = op->buffer_var; + e.buffer = buf; e.dtype = op->value.dtype().element_of(); - e.touched = arith::IntSet::Vector(op->index); + for (const auto& index : op->indices) { + e.touched.push_back(arith::IntSet::Vector(index)); + } e.type = kWrite; e.scope = scope; curr_stmt_.access.emplace_back(std::move(e)); @@ -151,8 +164,12 @@ void StorageAccessVisitor::VisitStmt_(const ForNode* op) { arith::IntSet::FromRange(Range::FromMinExtent(op->min, op->extent)); for (AccessEntry& e : s.access) { if (e.buffer.defined()) { - ICHECK(e.touched.defined()); - e.touched = arith::EvalSet(e.touched, relax_map); + ICHECK(e.touched.size()); + Array new_touched; + for (const auto& touched : e.touched) { + new_touched.push_back(arith::EvalSet(touched, relax_map)); + } + e.touched = std::move(new_touched); } } } @@ -196,8 +213,8 @@ void StorageAccessVisitor::VisitStmt_(const WhileNode* op) { void StorageAccessVisitor::VisitExpr_(const CallNode* op) { if (op->op.same_as(builtin::address_of())) { - const LoadNode* l = op->args[0].as(); - StmtExprVisitor::VisitExpr_(l); + const 
BufferLoadNode* load = op->args[0].as(); + StmtExprVisitor::VisitExpr_(load); } else if (op->op.same_as(builtin::tvm_access_ptr())) { ICHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); @@ -213,7 +230,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { e.threads = env_threads(); e.dtype = dtype; e.buffer = Downcast(op->args[1]); - e.touched = arith::IntSet::FromRange(Range::FromMinExtent(offset, extent)); + e.touched = {arith::IntSet::FromRange(Range::FromMinExtent(offset, extent))}; e.scope = scope; if (flag->value & 1) { e.type = kRead; diff --git a/src/tir/transforms/storage_access.h b/src/tir/transforms/storage_access.h index 9dc4c923b054..a48ee73f17fc 100644 --- a/src/tir/transforms/storage_access.h +++ b/src/tir/transforms/storage_access.h @@ -61,8 +61,11 @@ class StorageAccessVisitor : public StmtExprVisitor { Var buffer = NullValue(); /*! \brief The access data type */ DataType dtype; - /*! \brief The touched access range */ - arith::IntSet touched; + /*! \brief The touched access range + * + * Has one IntSet for each index in the buffer being accessed. + */ + Array touched; /*! \brief The type of access */ AccessType type; /*! \brief The storage scope */ @@ -80,6 +83,8 @@ class StorageAccessVisitor : public StmtExprVisitor { // override visitor pattern void VisitExpr_(const LoadNode* op) final; void VisitStmt_(const StoreNode* op) final; + void VisitExpr_(const BufferLoadNode* op) final; + void VisitStmt_(const BufferStoreNode* op) final; void VisitStmt_(const EvaluateNode* op) final; void VisitStmt_(const AttrStmtNode* op) final; void VisitStmt_(const ForNode* op) final; diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 783ad13e1ad0..2bc081483ccd 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -37,6 +37,7 @@ #include #include +#include #include "../../arith/ir_visitor_with_analyzer.h" #include "../../runtime/thread_storage_scope.h" @@ -163,43 +164,49 @@ class BufferShapeLegalize : public StmtExprMutator { } Stmt VisitStmt_(const BufferStoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - ICHECK(op); + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); + } - auto it = buf_map_.find(op->buffer); + template + Node VisitBufferAccess(Node node) { + auto it = buf_map_.find(node->buffer); if (it != buf_map_.end()) { const BufferEntry& entry = it->second; - ICHECK(entry.in_scope) << "Cannot store to an out-of-scope buffer"; + ICHECK(entry.in_scope) << "Cannot access an out-of-scope buffer"; - BufferStore updated = GetRef(op); - auto write_ptr = updated.CopyOnWrite(); - write_ptr->indices = update_indices(op->indices, entry.index_offsets); - write_ptr->buffer = entry.remap_to; - stmt = updated; - } + Array indices = node->indices; + if (entry.index_offsets.size()) { + ICHECK_GE(entry.index_offsets.size(), indices.size()) + << "Cannot bind buffer to a shape of lower dimension."; - return stmt; - } + Array new_indices; - PrimExpr VisitExpr_(const BufferLoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - ICHECK(op); + // Pad leading indices with zero, matching the "fuzzy_match" + // behavior from ArgBinder::BindBuffer. 
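  // A worked example with hypothetical values: for
  // entry.index_offsets = {2, 4} and a single access index {j}, diff is 1,
  // so the rewritten indices become {0, j - 4}: one zero padding the
  // leading dimension, then each remaining index shifted by its
  // realization offset.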
+ size_t diff = entry.index_offsets.size() - indices.size(); + for (size_t i = 0; i < diff; i++) { + new_indices.push_back(0); + } - auto it = buf_map_.find(op->buffer); - if (it != buf_map_.end()) { - const BufferEntry& entry = it->second; - ICHECK(entry.in_scope) << "Cannot read from an out-of-scope buffer"; + // Offset indices used to access buffers of a reduced size. + for (size_t i = 0; i < indices.size(); i++) { + PrimExpr offset = entry.index_offsets[i + diff]; + new_indices.push_back(indices[i] - offset); + } + indices = new_indices; + } - BufferLoad updated = GetRef(op); - auto write_ptr = updated.CopyOnWrite(); - write_ptr->indices = update_indices(op->indices, entry.index_offsets); + auto write_ptr = node.CopyOnWrite(); + write_ptr->indices = indices; write_ptr->buffer = entry.remap_to; - expr = updated; } - - return expr; + return node; } Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -341,36 +348,6 @@ class BufferShapeLegalize : public StmtExprMutator { return stmt; } - Array update_indices(const Array& indices, const Array& offsets) { - // offsets come from BufferRealizeNode::bounds, which is allowed - // to be empty to indicate realization of the full shape of the - // buffer. In that case, the indices do not need to be modified, - // but may need to be extended with leading zeroes. - if (offsets.size() == 0) { - return indices; - } - - ICHECK_GE(offsets.size(), indices.size()) - << "Cannot bind buffer to a shape of lower dimension."; - - Array new_indices; - - // Pad leading indices with zero, matching the "fuzzy_match" - // behavior from ArgBinder::BindBuffer. - size_t diff = offsets.size() - indices.size(); - for (size_t i = 0; i < diff; i++) { - new_indices.push_back(0); - } - - // Offset indices used to access buffers of a reduced size. - for (size_t i = 0; i < indices.size(); i++) { - PrimExpr offset = offsets[i + diff]; - new_indices.push_back(indices[i] - offset); - } - - return new_indices; - } - std::unordered_map var_remap_; std::unordered_set extern_buffers_; @@ -516,6 +493,19 @@ class BufferStrideLegalize : public StmtExprMutator { } } + // AllocateNodes may be present from tvm.tir.ir_builder. This can + // be simplified in the future by having AllocateNode hold a buffer, + // rather than a buffer_var. 
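  /* A minimal sketch (simplified from VisitBufferAccess below) of the
   * lazy registration this tracking enables: when a BufferLoad or
   * BufferStore references a buffer whose only declaration is an
   * AllocateNode, a BufferEntry is created on first use rather than at a
   * declaration site.
   *
   * \code
   *   if (!buf_map_.count(buffer) && allocate_node_var_.count(buffer->data.get())) {
   *     BufferEntry entry;
   *     entry.remap_to = WithStrides(buffer);
   *     entry.in_scope = true;
   *     buf_map_[buffer] = entry;
   *   }
   * \endcode
   */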
+ Stmt VisitStmt_(const AllocateNode* op) final { + allocate_node_var_.insert(op->buffer_var.get()); + return StmtExprMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const AllocateConstNode* op) final { + allocate_node_var_.insert(op->buffer_var.get()); + return StmtExprMutator::VisitStmt_(op); + } + Stmt VisitStmt_(const BufferRealizeNode* op) final { Buffer key = op->buffer; Buffer with_strides = WithStrides(op->buffer); @@ -536,28 +526,36 @@ class BufferStrideLegalize : public StmtExprMutator { return BufferRealize(with_strides, op->bounds, op->condition, op->body, op->span); } - PrimExpr VisitExpr_(const BufferLoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - - auto it = buf_map_.find(op->buffer); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer; - const BufferEntry& e = it->second; - ICHECK(e.in_scope) << "Cannot read a buffer that is already out of scope"; + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } - return BufferLoad(e.remap_to, op->indices, op->span); + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); } - Stmt VisitStmt_(const BufferStoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); + template + Node VisitBufferAccess(Node node) { + auto alloc_key = node->buffer->data.get(); + if (!buf_map_.count(node->buffer) && allocate_node_var_.count(alloc_key)) { + BufferEntry entry; + entry.remap_to = WithStrides(node->buffer); + entry.in_scope = true; + entry.is_external = false; + buf_map_[node->buffer] = entry; + } - auto it = buf_map_.find(op->buffer); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer; + auto it = buf_map_.find(node->buffer); + ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << node->buffer; const BufferEntry& e = it->second; - ICHECK(e.in_scope) << "Cannot write to a buffer that is already out of scope"; + ICHECK(e.in_scope) << "Cannot access a buffer " << node->buffer->name << ", out of scope"; - return BufferStore(e.remap_to, op->value, op->indices, op->span); + auto writer = node.CopyOnWrite(); + writer->buffer = e.remap_to; + + return node; } private: @@ -579,6 +577,10 @@ class BufferStrideLegalize : public StmtExprMutator { std::unordered_map buf_map_; + // Set of vars that have occurred in an AllocateNode, but haven't + // yet occurred in a BufferLoad/BufferStore. + std::unordered_set allocate_node_var_; + IRVisitorWithAnalyzer* bound_analyzer_; }; @@ -778,39 +780,13 @@ class BufferBindUnwrapper : public StmtExprMutator { } Stmt VisitStmt_(const StoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - auto it = var_remap_.find(op->buffer_var.get()); - if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { - // TODO(Lunderberg): Change from warning to error once all mixed - // use of physical/logical layouts is removed. - DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), " - << "but is accessed as a pointer (StoreNode)."; - - ICHECK(it->second.as()); - Var new_buf_var = Downcast(it->second); - return Store(new_buf_var, op->value, op->index, op->predicate); - } else { - return stmt; - } + LOG(FATAL) << "Unexpected use of deprecated StoreNode. 
Please use BufferStoreNode instead."; + return Stmt(); } PrimExpr VisitExpr_(const LoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - auto it = var_remap_.find(op->buffer_var.get()); - if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { - // TODO(Lunderberg): Change from warning to error once all mixed - // use of physical/logical layouts is removed. - DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), " - << "but is accessed as a pointer (LoadNode)."; - - ICHECK(it->second.as()); - Var new_buf_var = Downcast(it->second); - return Load(op->dtype, new_buf_var, op->index, op->predicate); - } else { - return expr; - } + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); } Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -868,14 +844,24 @@ class BufferBindUnwrapper : public StmtExprMutator { return out; } + // AllocateNodes may be present from tvm.tir.ir_builder. This can + // be simplified in the future by having AllocateNode hold a buffer, + // rather than a buffer_var. + Stmt VisitStmt_(const AllocateNode* op) final { + allocate_node_var_.insert(op->buffer_var.get()); + return StmtExprMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const AllocateConstNode* op) final { + allocate_node_var_.insert(op->buffer_var.get()); + return StmtExprMutator::VisitStmt_(op); + } + PrimExpr VisitExpr_(const BufferLoadNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); - auto it = buf_map_.find(op->buffer.get()); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer; - const BufferEntry& e = it->second; - ICHECK(e.in_scope) << "Cannot read from buffer " << op->buffer << ", out of scope."; + const BufferEntry& e = GetBufferEntry(op->buffer); if (e.remap) { return BufferLoad(e.remap->target, @@ -889,10 +875,7 @@ class BufferBindUnwrapper : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - auto it = buf_map_.find(op->buffer.get()); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer; - const BufferEntry& e = it->second; - ICHECK(e.in_scope) << "Cannot write to buffer" << op->buffer << ", out of scope."; + const BufferEntry& e = GetBufferEntry(op->buffer); if (e.remap) { return BufferStore(e.remap->target, op->value, @@ -933,10 +916,7 @@ class BufferBindUnwrapper : public StmtExprMutator { op = stmt.as(); ICHECK(op != nullptr); - const auto& key = op->buffer.get(); - auto it = buf_map_.find(key); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; - const BufferEntry& e = it->second; + const BufferEntry& e = GetBufferEntry(op->buffer); ICHECK(e.in_scope) << "Read a buffer that is already out of scope"; ICHECK_EQ(e.buffer->shape.size(), op->bounds.size()) @@ -1066,16 +1046,145 @@ class BufferBindUnwrapper : public StmtExprMutator { std::unique_ptr remap{nullptr}; }; + const BufferEntry& GetBufferEntry(Buffer buffer) { + auto alloc_key = buffer->data.get(); + if (!buf_map_.count(buffer.get()) && allocate_node_var_.count(alloc_key)) { + BufferEntry entry; + entry.buffer = buffer; + buf_map_[buffer.get()] = std::move(entry); + } + + auto it = buf_map_.find(buffer.get()); + ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer; + const BufferEntry& e = it->second; + ICHECK(e.in_scope) << "Cannot access a buffer " << buffer->name << ", out of scope"; + return it->second; + } + // The buffer 
assignment map // Variable remap std::unordered_map var_remap_; // Buffer map std::unordered_map buf_map_; + // Set of vars that have occurred in an AllocateNode, but haven't + // yet occurred in a BufferLoad/BufferStore. + std::unordered_set allocate_node_var_; // Analyzer for the variable bounds, used to simplify the bounds populator. We really need the // analyzer from it. However IRVisitorWithAnalyzer* bound_analyzer_; }; +class ApplyLayoutTransforms : public StmtExprMutator { + public: + static transform::Pass Pass() { + auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) { + auto lookup = func->attrs.GetAttr>>("layout_transform_map"); + + if (!lookup) { + return func; + } + + Map> layout_transforms = lookup.value(); + + auto fptr = func.CopyOnWrite(); + + auto mutator = ApplyLayoutTransforms(layout_transforms); + fptr->buffer_map = mutator.UpdateExternBufferMap(fptr->buffer_map); + fptr->body = mutator(std::move(fptr->body)); + + return WithoutAttr(std::move(func), "layout_transform_map"); + }; + return transform::CreatePrimFuncPass(pass_func, 0, "tir.ApplyLayoutTransforms", {}); + } + + explicit ApplyLayoutTransforms(Map> layout_transforms) + : layout_transforms_(layout_transforms) {} + + Map UpdateExternBufferMap(const Map& buffer_map) { + Map output; + for (const auto& kv : buffer_map) { + output.Set(kv.first, GetBufferRemap(kv.second, true)); + } + return output; + } + + Stmt VisitStmt_(const BufferRealizeNode* op) final { + // Call once so that load/store nodes can read from the cached + // value. + GetBufferRemap(op->buffer, true); + + auto realize = Downcast(StmtExprMutator::VisitStmt_(op)); + + auto lookup = layout_transforms_.Get(op->buffer); + if (lookup) { + auto write_ptr = realize.CopyOnWrite(); + write_ptr->buffer = GetBufferRemap(op->buffer, true); + + Array transforms = lookup.value(); + for (const auto& transform : transforms) { + write_ptr->bounds = transform->MapRanges(realize->bounds); + } + } + + return std::move(realize); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); + } + + template + Node VisitBufferAccess(Node node) { + auto lookup = layout_transforms_.Get(node->buffer); + if (lookup) { + auto write_ptr = node.CopyOnWrite(); + + write_ptr->buffer = GetBufferRemap(node->buffer); + + Array transforms = lookup.value(); + for (const auto& transform : transforms) { + write_ptr->indices = transform->MapIndices(node->indices); + } + } + return node; + } + + private: + //! \brief Given a buffer, return the buffer it should be remapped into. 
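  //!
  //! For illustration (a hypothetical transform, assuming IndexMap's
  //! (initial_indices, final_indices) constructor used by the
  //! transform_layout schedule primitive), a transpose rewrites both the
  //! shape cached here and the indices rewritten in VisitBufferAccess:
  //!
  //! \code
  //!   Var i("i"), j("j");
  //!   IndexMap transform({i, j}, {j, i});  // transpose
  //!   transform->MapShape({4, 8});         // yields {8, 4}
  //!   Var x("x"), y("y");
  //!   transform->MapIndices({x, y});       // yields {y, x}
  //! \endcode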
+ Buffer GetBufferRemap(Buffer buf, bool allow_alloc = false) { + auto key = buf.get(); + auto it = buf_map_.find(key); + if (it != buf_map_.end()) { + return it->second; + } + + ICHECK(allow_alloc) << "Buffer " << buf << " accessed before declaration."; + + auto lookup = layout_transforms_.Get(buf); + if (lookup) { + Array transforms = lookup.value(); + + auto write_ptr = buf.CopyOnWrite(); + for (const auto& transform : transforms) { + write_ptr->shape = transform->MapShape(buf->shape); + } + } + + buf_map_[key] = buf; + return buf; + } + + std::unordered_map buf_map_; + + Map> layout_transforms_; +}; + class StorageFlattener : public StmtExprMutator { public: static transform::Pass Pass(int cache_line_size, bool create_bound_attributes) { @@ -1084,9 +1193,16 @@ class StorageFlattener : public StmtExprMutator { bound_analyzer(func->body); + auto pass = StorageFlattener(func->buffer_map, cache_line_size, create_bound_attributes, + &bound_analyzer); + + Map preflattened_buffer_map = + Merge(func->buffer_map, func->preflattened_buffer_map); + auto fptr = func.CopyOnWrite(); - fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes, - &bound_analyzer)(std::move(fptr->body)); + fptr->body = pass(std::move(fptr->body)); + fptr->preflattened_buffer_map = preflattened_buffer_map; + fptr->buffer_map = pass.UpdatedBufferMap(); return func; }; return transform::CreatePrimFuncPass(pass_func, 0, "tir.StorageFlattener", {}); @@ -1098,23 +1214,39 @@ class StorageFlattener : public StmtExprMutator { for (auto kv : extern_buffer_map) { BufferEntry e; e.buffer = kv.second; + e.flattened_buffer = e.buffer.GetFlattenedBuffer(); + // TODO(Lunderberg): Move the handling of boolean into a + // dedicated pass. + + // Boolean tensors are backed by a Int8 array. + if (e.buffer->dtype == DataType::Bool()) { + { + auto writer = e.buffer.CopyOnWrite(); + writer->dtype = DataType::Int(8); + } + { + auto writer = e.flattened_buffer.CopyOnWrite(); + writer->dtype = DataType::Int(8); + } + } e.external = true; buf_map_[kv.second] = e; + + updated_extern_buffer_map_.Set(kv.first, e.flattened_buffer); } cache_line_size_ = cache_line_size; } + Map UpdatedBufferMap() { return updated_extern_buffer_map_; } + Stmt VisitStmt_(const StoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - auto it = var_remap_.find(op->buffer_var.get()); - if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { - ICHECK(it->second.as()); - Var buf_var = Downcast(it->second); - return Store(buf_var, op->value, op->index, op->predicate); - } else { - return stmt; - } + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. 
Please use BufferLoadNode instead."; + return PrimExpr(); } Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -1130,9 +1262,8 @@ class StorageFlattener : public StmtExprMutator { if (op->attr_key == attr::double_buffer_scope && op->node->IsInstance()) { auto buffer = Downcast(op->node); Stmt body = this->VisitStmt(op->body); - auto it = buf_map_.find(buffer); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer; - body = AttrStmt(it->second.buffer->data, op->attr_key, op->value, std::move(body)); + const auto& entry = GetBufferEntry(buffer); + body = AttrStmt(entry.flattened_buffer->data, op->attr_key, op->value, std::move(body)); return body; } return StmtExprMutator::VisitStmt_(op); @@ -1143,15 +1274,21 @@ class StorageFlattener : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - const auto& key = op->buffer; + const BufferEntry& e = GetBufferEntry(op->buffer); - auto it = buf_map_.find(key); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; + // Handle casts from the value's dtype to the dtype of the backing + // array. + PrimExpr value = op->value; + if (value.dtype() == DataType::Bool()) { + ICHECK_EQ(e.flattened_buffer->dtype, DataType::Int(8)) + << "Expected int8 backing array for boolean tensor, but received " + << e.flattened_buffer->dtype; + value = tir::Cast(DataType::Int(8), value); + } - const BufferEntry& e = it->second; - ICHECK(e.in_scope) << "Cannot write to " << op->buffer << ", out of scope."; + auto flattened_indices = e.buffer->ElemOffset(op->indices); - Stmt body = e.buffer.vstore(op->indices, op->value); + Stmt body = BufferStore(e.flattened_buffer, value, flattened_indices, op->span); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); } @@ -1165,6 +1302,19 @@ class StorageFlattener : public StmtExprMutator { return body; } + // AllocateNodes may be present from tvm.tir.ir_builder. This can + // be simplified in the future by having AllocateNode hold a buffer, + // rather than a buffer_var. 
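  /* As with the other passes in this file, a BufferEntry is created
   * lazily for these vars on first access (see GetBufferEntry below). A
   * short sketch of the boolean handling applied there and at each
   * access: boolean tensors are backed by an int8 array, so stores cast
   * to int8 and loads cast back to bool.
   *
   * \code
   *   // on store:
   *   value = tir::Cast(DataType::Int(8), value);
   *   // on load:
   *   val = tir::Cast(DataType::Bool(), val);
   * \endcode
   */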
+ Stmt VisitStmt_(const AllocateNode* op) final { + allocate_node_var_.insert(op->buffer_var.get()); + return StmtExprMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const AllocateConstNode* op) final { + allocate_node_var_.insert(op->buffer_var.get()); + return StmtExprMutator::VisitStmt_(op); + } + Stmt VisitStmt_(const BufferRealizeNode* op) final { const auto& key = op->buffer; @@ -1191,12 +1341,11 @@ class StorageFlattener : public StmtExprMutator { "Please run BufferShapeLegalize first."; } - Array shape = op->buffer->shape; StorageScope skey = StorageScope::Create(GetPtrStorageScope(op->buffer->data)); // use small alignment for small arrays auto dtype = op->buffer->dtype; - size_t const_size = AllocateNode::ConstantAllocationSize(shape); + size_t const_size = AllocateNode::ConstantAllocationSize(op->buffer->shape); int align = GetTempAllocaAlignment(dtype, const_size); if (skey.tag.length() != 0) { MemoryInfo info = GetMemoryInfo(skey.to_string()); @@ -1206,35 +1355,27 @@ class StorageFlattener : public StmtExprMutator { << "Allocation exceed bound of memory tag " << skey.to_string(); } } - Array strides = op->buffer->strides; - e.buffer = Buffer(op->buffer->data, op->buffer->dtype, shape, strides, PrimExpr(), - op->buffer->name, align, 0, kDefault); + e.buffer = Buffer(op->buffer->data, op->buffer->dtype, op->buffer->shape, op->buffer->strides, + PrimExpr(), op->buffer->name, align, 0, kDefault); + e.flattened_buffer = e.buffer.GetFlattenedBuffer(); + + // TODO(Lunderberg): Move the handling of boolean into a + // dedicated pass. + + // Boolean tensors are backed by a Int8 array. + if (e.flattened_buffer->dtype == DataType::Bool()) { + auto writer = e.flattened_buffer.CopyOnWrite(); + writer->dtype = DataType::Int(8); + } buf_map_[key] = e; Stmt body = this->VisitStmt(op->body); buf_map_[key].in_scope = false; - Stmt ret; - DataType storage_type = e.buffer->dtype; - // specially handle bool, lower its storage - // type to beDataType::Int(8)(byte) - if (storage_type == DataType::Bool()) { - storage_type = DataType::Int(8); - } - if (strides.size() != 0) { - int first_dim = 0; - ret = Allocate(e.buffer->data, storage_type, - {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]}, - make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); - } else { - shape = e.buffer->shape; - if (shape.size() == 0) { - shape.push_back(make_const(DataType::Int(32), 1)); - } - ret = Allocate(e.buffer->data, storage_type, shape, - make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); - } + Stmt ret = + Allocate(e.flattened_buffer->data, e.flattened_buffer->dtype, e.flattened_buffer->shape, + make_const(DataType::Bool(e.flattened_buffer->dtype.lanes()), true), body); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound, @@ -1244,19 +1385,6 @@ class StorageFlattener : public StmtExprMutator { } } - PrimExpr VisitExpr_(const LoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - auto it = var_remap_.find(op->buffer_var.get()); - if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { - ICHECK(it->second.as()); - Var buf_var = Downcast(it->second); - return Load(op->dtype, buf_var, op->index, op->predicate); - } else { - return expr; - } - } - PrimExpr VisitExpr_(const VarNode* op) final { auto it = var_remap_.find(op); if (it != var_remap_.end()) { @@ -1270,17 +1398,23 @@ class StorageFlattener : public StmtExprMutator { PrimExpr expr = 
StmtExprMutator::VisitExpr_(op); op = expr.as(); - const auto& key = op->buffer; - - auto it = buf_map_.find(key); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; - const BufferEntry& e = it->second; - ICHECK(e.in_scope) << "Cannot read to " << op->buffer << ", out of scope."; + const BufferEntry& e = GetBufferEntry(op->buffer); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); } - return e.buffer.vload(op->indices, e.buffer->dtype); + + auto flattened_indices = e.buffer->ElemOffset(op->indices); + PrimExpr val = BufferLoad(e.flattened_buffer, flattened_indices, op->span); + + if (op->dtype == DataType::Bool()) { + ICHECK_EQ(e.flattened_buffer->dtype, DataType::Int(8)) + << "Expected int8 backing array for boolean tensor, but received " + << e.flattened_buffer->dtype; + val = tir::Cast(DataType::Bool(), val); + } + + return val; } Stmt VisitStmt_(const PrefetchNode* op) final { @@ -1288,10 +1422,7 @@ class StorageFlattener : public StmtExprMutator { op = stmt.as(); ICHECK(op != nullptr); - const auto& key = op->buffer; - auto it = buf_map_.find(key); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; - const BufferEntry& e = it->second; + const BufferEntry& e = GetBufferEntry(op->buffer); ICHECK(e.in_scope) << "Cannot prefetch " << op->buffer << ", out of scope."; ICHECK_EQ(e.buffer->shape.size(), op->bounds.size()) @@ -1363,8 +1494,10 @@ class StorageFlattener : public StmtExprMutator { }; // The buffer entry in the flatten map struct BufferEntry { - // the buffer of storage + // The buffer object Buffer buffer; + // The updated buffer object, after flattening has been applied. + Buffer flattened_buffer; // Whether the buffer is external bool external{false}; // Whether the buffer is currently in scope. @@ -1389,14 +1522,42 @@ class StorageFlattener : public StmtExprMutator { for (size_t i = 1; i < shape.size(); ++i) { bound = Mul(bound, Mul(make_const(bound.dtype(), type.lanes()), shape[i])); } - return bound; + Array bounds{bound}; + + return Call(DataType::Handle(), builtin::tvm_tuple(), bounds); + } + + const BufferEntry& GetBufferEntry(Buffer buffer) { + auto alloc_key = buffer->data.get(); + if (!buf_map_.count(buffer) && allocate_node_var_.count(alloc_key)) { + BufferEntry entry; + entry.buffer = buffer; + entry.flattened_buffer = buffer.GetFlattenedBuffer(); + // Boolean tensors are backed by a Int8 array. + if (entry.flattened_buffer->dtype == DataType::Bool()) { + auto writer = entry.flattened_buffer.CopyOnWrite(); + writer->dtype = DataType::Int(8); + } + buf_map_[buffer] = std::move(entry); + } + + auto it = buf_map_.find(buffer); + ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer; + const BufferEntry& e = it->second; + ICHECK(e.in_scope) << "Cannot access a buffer " << buffer->name << ", out of scope"; + return it->second; } // The buffer assignment map // Variable remap std::unordered_map var_remap_; + // Set of vars that have occurred in an AllocateNode, but haven't + // yet occurred in a BufferLoad/BufferStore. + std::unordered_set allocate_node_var_; // Buffer map std::unordered_map buf_map_; + // The extern buffer map, updated to include flattened buffers. + Map updated_extern_buffer_map_; // Collects shapes. std::vector>> shape_collector_; // bounds populator. We really need the analyzer from it. 
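  /* For reference, the flattening applied above (a sketch, assuming a
   * 2-d buffer of shape {H, W} with no explicit strides): the flattened
   * indices come from Buffer::ElemOffset, so that
   *
   * \code
   *   BufferStore(buffer, value, {y, x})
   *   // becomes
   *   BufferStore(buffer.GetFlattenedBuffer(), value, buffer->ElemOffset({y, x}))
   *   // where ElemOffset({y, x}) evaluates to {y * W + x}
   * \endcode
   */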
@@ -1495,6 +1656,7 @@ PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_at
       BufferStrideLegalize::Pass(),
       ThreadScopePropagate::Pass(),
       BufferBindUnwrapper::Pass(),
+      ApplyLayoutTransforms::Pass(),
       StorageFlattener::Pass(cache_line_size, create_bound_attributes),
       AssertSimplifier::Pass(),
   },
@@ -1510,6 +1672,9 @@ PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_at
 namespace transform {
+TVM_REGISTER_GLOBAL("tir.transform.ApplyLayoutTransforms")
+    .set_body_typed(ApplyLayoutTransforms::Pass);
+
 // TODO(tvm-team): consolidate configs to the PassContext
 Pass StorageFlatten(int cache_line_size, bool create_bound_attributes) {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index 9d90e0b3f226..6e8e824c5fa2 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -89,12 +89,17 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
     alloc_info_[buf].level = level;
     StmtExprVisitor::VisitStmt_(op);
   }
+  void VisitStmt_(const StoreNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
+  }
+
+  void VisitStmt_(const BufferStoreNode* op) final {
     scope_.push_back(StmtEntry());
     // visit subexpr
     StmtExprVisitor::VisitStmt_(op);
     // Add write access.
-    const VarNode* buf = op->buffer_var.get();
+    const VarNode* buf = op->buffer->data.get();
     auto it = alloc_info_.find(buf);
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size());
@@ -107,6 +112,22 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
       linear_seq_.push_back(e);
     }
   }
+
+  void VisitExpr_(const LoadNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    // Add read access.
+    StmtExprVisitor::VisitExpr_(op);
+    const VarNode* buf = op->buffer->data.get();
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
+      scope_[it->second.level].touched.push_back(buf);
+    }
+  }
+
   void VisitStmt_(const EvaluateNode* op) final {
     scope_.push_back(StmtEntry());
     // visit subexpr
@@ -118,24 +139,18 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
       linear_seq_.push_back(e);
     }
   }
-  void VisitExpr_(const LoadNode* op) final {
-    // Add write access.
-    StmtExprVisitor::VisitExpr_(op);
-    const VarNode* buf = op->buffer_var.get();
-    auto it = alloc_info_.find(buf);
-    if (it != alloc_info_.end() && it->second.alloc) {
-      ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
-      scope_[it->second.level].touched.push_back(buf);
-    }
-  }
+
   void VisitExpr_(const CallNode* op) final {
     if (op->op.same_as(builtin::address_of())) {
-      const LoadNode* l = op->args[0].as<LoadNode>();
-      this->VisitExpr(l->index);
+      const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
+      for (const auto& index : load->indices) {
+        this->VisitExpr(index);
+      }
     } else {
       StmtExprVisitor::VisitExpr_(op);
     }
   }
+
   void VisitExpr_(const VarNode* buf) final {
     // Directly referencing the variable counts as a read.
auto it = alloc_info_.find(buf); @@ -144,6 +159,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { scope_[it->second.level].touched.push_back(buf); } } + template void VisitNewScope(const T* op) { scope_.push_back(StmtEntry()); @@ -164,6 +180,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { ICHECK_NE(end_index, 0U); linear_seq_[begin_index].scope_pair_offset = end_index - begin_index; } + void VisitStmt_(const AttrStmtNode* op) final { // Only record the outer most thread extent. if (op->attr_key == attr::thread_extent && !in_thread_env_) { @@ -178,6 +195,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { StmtExprVisitor::VisitStmt_(op); } } + void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); } void VisitStmt_(const ForNode* op) final { VisitNewScope(op); } @@ -240,6 +258,8 @@ class InplaceOpVerifier : public StmtExprVisitor { VisitStmt_(static_cast(stmt)); } else if (stmt->IsInstance()) { VisitStmt_(static_cast(stmt)); + } else if (stmt->IsInstance()) { + VisitStmt_(static_cast(stmt)); } else { return false; } @@ -266,17 +286,21 @@ class InplaceOpVerifier : public StmtExprVisitor { } void VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + } + + void VisitStmt_(const BufferStoreNode* op) final { ++mem_nest_; - this->VisitExpr(op->index); + for (const auto& index : op->indices) { + this->VisitExpr(index); + } --mem_nest_; - if (op->buffer_var.get() == dst_) { + if (op->buffer->data.get() == dst_) { store_ = op; this->VisitExpr(op->value); - this->VisitExpr(op->predicate); store_ = nullptr; } else { this->VisitExpr(op->value); - this->VisitExpr(op->predicate); } } @@ -290,7 +314,11 @@ class InplaceOpVerifier : public StmtExprVisitor { } void VisitExpr_(const LoadNode* op) final { - const VarNode* buf = op->buffer_var.get(); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + } + + void VisitExpr_(const BufferLoadNode* op) final { + const VarNode* buf = op->buffer->data.get(); // cannot read from dst_ (no reduction) if (buf == dst_) { result_ = false; @@ -302,11 +330,19 @@ class InplaceOpVerifier : public StmtExprVisitor { return; } if (src_ == buf) { - if (store_ == nullptr || store_->value.dtype() != op->dtype || - !tir::ExprDeepEqual()(store_->index, op->index)) { + if (store_ == nullptr || store_->value.dtype() != op->dtype) { result_ = false; return; } + ICHECK_EQ(store_->indices.size(), op->indices.size()) + << "Store/Load occur to the same buffer " << buf->name_hint + << " with differing number of indices"; + for (size_t i = 0; i < store_->indices.size(); i++) { + if (!tir::ExprDeepEqual()(store_->indices[i], op->indices[i])) { + result_ = false; + return; + } + } } ++mem_nest_; StmtExprVisitor::VisitExpr_(op); @@ -324,7 +360,7 @@ class InplaceOpVerifier : public StmtExprVisitor { // it is not safe to inplace when there is nested load like A[B[i]] int mem_nest_{0}; // The current store to be inspected - const StoreNode* store_{nullptr}; + const BufferStoreNode* store_{nullptr}; }; /* \brief Rewrite and merge memory allocation. 
@@ -355,22 +391,62 @@ class StoragePlanRewriter : public StmtExprMutator { } return stmt; } + Stmt VisitStmt_(const StoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - auto it = alloc_map_.find(op->buffer_var.get()); - if (it == alloc_map_.end()) return stmt; - return Store(it->second->alloc_var, op->value, - RemapIndex(op->value.dtype(), op->index, it->second), op->predicate); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); } + PrimExpr VisitExpr_(const LoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - auto it = alloc_map_.find(op->buffer_var.get()); - if (it == alloc_map_.end()) return expr; - return Load(op->dtype, it->second->alloc_var, RemapIndex(op->dtype, op->index, it->second), - op->predicate); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); + } + + template + Node VisitBufferAccess(Node node) { + auto it = alloc_map_.find(node->buffer->data.get()); + if (it != alloc_map_.end()) { + Buffer buf = RemapBuffer(node->buffer, it->second->alloc_var); + + Array indices = node->indices; + indices.Set(indices.size() - 1, + RemapIndex(node->buffer->dtype, indices[indices.size() - 1], it->second)); + + auto writer = node.CopyOnWrite(); + writer->buffer = buf; + writer->indices = indices; + } + return node; + } + + Buffer RemapBuffer(Buffer buf, Var new_backing_array) { + auto key = buf.get(); + auto it = buffer_remap_.find(key); + if (it != buffer_remap_.end()) { + ICHECK_EQ(it->second->data.get(), new_backing_array.get()) + << "Cannot remap buffer " << buf->name << " to use backing array " + << new_backing_array->name_hint << ", previously used backing array " + << it->second->data->name_hint; + return it->second; + } + + Buffer remapped = Buffer(new_backing_array, buf->dtype, buf->shape, buf->strides, + buf->elem_offset, new_backing_array->name_hint, buf->data_alignment, + buf->offset_factor, buf->buffer_type, buf->axis_separators, buf->span); + buffer_remap_[key] = remapped; + return remapped; + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); + } + PrimExpr VisitExpr_(const VarNode* op) final { auto it = alloc_map_.find(op); if (it != alloc_map_.end()) { @@ -890,6 +966,8 @@ class StoragePlanRewriter : public StmtExprMutator { std::unordered_map alloc_map_; // The allocations std::vector > alloc_vec_; + // The buffer objects being remapped + std::unordered_map buffer_remap_; // analyzer arith::Analyzer analyzer_; }; @@ -902,7 +980,8 @@ struct BufferVarInfo { kPrimFuncParam = (1 << 0), kPrimFuncBufferMap = (1 << 1), kAllocateNode = (1 << 2), - kLetNode = (1 << 3), + kAllocateConstNode = (1 << 3), + kLetNode = (1 << 4), }; // The tir::Var that represents this buffer. @@ -1006,20 +1085,29 @@ class VectorTypeAccessChecker : public StmtExprVisitor { } void VisitExpr_(const LoadNode* op) final { - OnArrayAccess(op->dtype, op->buffer_var.get(), op->index, op->predicate); - StmtExprVisitor::VisitExpr_(op); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. 
Please use BufferLoadNode instead."; } void VisitStmt_(const StoreNode* op) final { - OnArrayAccess(op->value.dtype(), op->buffer_var.get(), op->index, op->predicate); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + } + + void VisitExpr_(const BufferLoadNode* op) final { + OnArrayAccess(op->dtype, op->buffer->data.get(), op->indices); + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const BufferStoreNode* op) final { + OnArrayAccess(op->value.dtype(), op->buffer->data.get(), op->indices); StmtExprVisitor::VisitStmt_(op); } + void VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::tvm_access_ptr())) { DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); PrimExpr index = op->args[2]; - OnArrayAccess(dtype, buffer, index, const_true(dtype.lanes())); + OnArrayAccess(dtype, buffer, {index}); } StmtExprVisitor::VisitExpr_(op); } @@ -1035,7 +1123,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { void VisitStmt_(const AllocateConstNode* op) final { const Array& extents = op->extents; PrimExpr extent = extents.size() ? extents[extents.size() - 1] : NullValue(); - OnArrayDeclaration(op->buffer_var, op->dtype, extent, BufferVarInfo::kAllocateNode); + OnArrayDeclaration(op->buffer_var, op->dtype, extent, BufferVarInfo::kAllocateConstNode); StmtExprVisitor::VisitStmt_(op); } @@ -1101,8 +1189,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { * * @param predicate The predicate used for the store/load. */ - void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const PrimExpr& index, - const PrimExpr& predicate) { + void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const Array& indices) { auto it = info_map_.find(buffer); ICHECK(it != info_map_.end()) << "Load/Store of buffer " << buffer->name_hint << " (" << buffer << ") occurred before its declaration."; @@ -1118,6 +1205,11 @@ class VectorTypeAccessChecker : public StmtExprVisitor { var_info.element_dtype = value_dtype.element_of(); } + int index_lanes = 1; + for (const auto& index : indices) { + index_lanes *= index.dtype().lanes(); + } + DataType access_dtype = value_dtype; int lanes_used = var_info.element_dtype.lanes(); @@ -1128,8 +1220,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor { // necessary because the C-based codegens do not yet support vectorized // pointer types (e.g. float16x4*). Once they do, this if statement should // instead be replaced by the below ICHECK_EQ. - if (index.dtype().lanes() * var_info.element_dtype.lanes() != value_dtype.lanes()) { - ICHECK_EQ(index.dtype().lanes(), value_dtype.lanes()); + if (index_lanes * var_info.element_dtype.lanes() != value_dtype.lanes()) { + ICHECK_EQ(index_lanes, value_dtype.lanes()); lanes_used = 1; var_info.element_dtype = var_info.element_dtype.with_lanes(1); } @@ -1138,22 +1230,24 @@ class VectorTypeAccessChecker : public StmtExprVisitor { // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615 // for discussion. - // ICHECK_EQ(index.dtype().lanes() * var_info.element_dtype.lanes(), value_dtype.lanes()) + // ICHECK_EQ(index_lanes * var_info.element_dtype.lanes(), value_dtype.lanes()) // << "Attempting to retrieve " << value_dtype.lanes() << " lanes of data with " - // << index.dtype().lanes() << " indices into an array whose elements have " + // << index_lanes << " indices into an array whose elements have " // << var_info.element_dtype.lanes() << " lanes. 
" - // << "Expected output with " << index.dtype().lanes() * var_info.element_dtype.lanes() + // << "Expected output with " << index_lanes * var_info.element_dtype.lanes() // << " lanes."; // If the index is a RampNode with stride of 1 and offset // divisible by the number of number of lanes, and the predicate // does not apply any masking, then this array access could be // vectorized. - const RampNode* ramp_index = index.as(); - if (ramp_index && is_one(ramp_index->stride) && is_one(predicate)) { - arith::ModularSet me = analyzer_.modular_set(ramp_index->base); - if ((me->coeff % ramp_index->lanes == 0) && (me->base % ramp_index->lanes == 0)) { - lanes_used = ramp_index->lanes; + if (indices.size()) { + const RampNode* ramp_index = indices[indices.size() - 1].as(); + if (ramp_index && is_one(ramp_index->stride)) { + arith::ModularSet me = analyzer_.modular_set(ramp_index->base); + if ((me->coeff % ramp_index->lanes == 0) && (me->base % ramp_index->lanes == 0)) { + lanes_used = ramp_index->lanes; + } } } @@ -1219,7 +1313,7 @@ class VectorTypeRewriter : public StmtExprMutator { VectorTypeRewriter(const std::unordered_map& info_map, bool rewrite_params = true, bool rewrite_buffer_map = true, bool rewrite_allocate_node = true, bool rewrite_indices = true, - bool rewrite_let_node = true) + bool rewrite_let_node = true, bool rewrite_allocate_const_node = true) : rewrite_indices_(rewrite_indices) { int rewrite_mask = 0; if (rewrite_params) { @@ -1234,6 +1328,9 @@ class VectorTypeRewriter : public StmtExprMutator { if (rewrite_let_node) { rewrite_mask |= BufferVarInfo::kLetNode; } + if (rewrite_allocate_const_node) { + rewrite_mask |= BufferVarInfo::kAllocateConstNode; + } // Rewrite any buffer variables whose preferred type isn't their current type. for (const auto& pair : info_map) { @@ -1252,55 +1349,92 @@ class VectorTypeRewriter : public StmtExprMutator { } PrimExpr VisitExpr_(const LoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); + } + Stmt VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. 
Please use BufferStoreNode instead."; + return Stmt(); + } + + template + Node VisitBufferAccess(Node node) { if (!rewrite_indices_) { - return expr; + return node; } - auto it = rewrite_map_.find(op->buffer_var.get()); + auto it = rewrite_map_.find(node->buffer->data.get()); if (it == rewrite_map_.end()) { - return expr; + return node; } const auto& info = it->second; - DataType out_dtype_base = info.new_element_dtype.element_of(); + Array indices = node->indices; - const RampNode* ramp_index = op->index.as(); + const RampNode* ramp_index = indices[indices.size() - 1].as(); if (ramp_index && is_one(ramp_index->stride)) { PrimExpr new_index = ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes); - return Load(out_dtype_base.with_lanes(op->dtype.lanes()), info.new_buffer_var, new_index, - const_true(new_index.dtype().lanes()), op->span); - } else { - return Load(out_dtype_base, info.new_buffer_var, op->index, op->predicate); + if (ramp_index->lanes != info.factor()) { + new_index = Ramp(new_index, ramp_index->stride, ramp_index->lanes / info.factor(), + ramp_index->span); + } + + indices.Set(indices.size() - 1, new_index); } + + auto writer = node.CopyOnWrite(); + writer->buffer = RemapBuffer(node->buffer); + writer->indices = indices; + + return node; } - Stmt VisitStmt_(const StoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + auto modified = VisitBufferAccess(node); - if (!rewrite_indices_) { - return stmt; + // Not needed for BufferStoreNode, so we can't just call + // LegalizeDtype() in VisitBufferAccess. + if (node.same_as(modified)) { + return std::move(node); + } else { + auto writer = modified.CopyOnWrite(); + writer->LegalizeDType(); + return std::move(modified); } + } - auto it = rewrite_map_.find(op->buffer_var.get()); - if (it == rewrite_map_.end()) { - return stmt; + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + Buffer RemapBuffer(Buffer buf) { + auto cache_key = buf.get(); + + auto cache_it = buffer_map_.find(cache_key); + if (cache_it != buffer_map_.end()) { + return cache_it->second; } - const auto& info = it->second; - const RampNode* ramp_index = op->index.as(); - if (ramp_index && is_one(ramp_index->stride)) { - PrimExpr new_index = - ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes); - return Store(info.new_buffer_var, op->value, new_index, const_true(new_index.dtype().lanes()), - op->span); - } else { - return Store(info.new_buffer_var, op->value, op->index, op->predicate, op->span); + auto info_it = rewrite_map_.find(buf->data.get()); + if (info_it != rewrite_map_.end()) { + auto& info = info_it->second; + + Array shape = buf->shape; + PrimExpr last_dim = shape[shape.size() - 1]; + shape.Set(shape.size() - 1, last_dim / make_const(last_dim.dtype(), info.factor())); + + auto writer = buf.CopyOnWrite(); + writer->data = info.new_buffer_var; + writer->dtype = info.new_element_dtype; + writer->shape = shape; } + + buffer_map_[cache_key] = buf; + return buf; } PrimExpr VisitExpr_(const CallNode* op) final { @@ -1324,9 +1458,9 @@ class VectorTypeRewriter : public StmtExprMutator { PrimExpr flag = op->args[4]; PrimExpr e_dtype = tir::TypeAnnotation(info.new_element_dtype); - PrimExpr factor = make_const(extent.dtype(), info.new_element_dtype.lanes()); - extent = 
extent / factor; - index = index / factor; + int factor = info.factor(); + extent = extent / make_const(extent.dtype(), factor); + index = index / make_const(index.dtype(), factor); Array acc_args{e_dtype, info.new_buffer_var, index, extent, flag}; return Call(info.new_element_dtype, builtin::tvm_access_ptr(), acc_args); @@ -1348,11 +1482,9 @@ class VectorTypeRewriter : public StmtExprMutator { Var new_buffer_var = info.new_buffer_var; - int factor = info.new_element_dtype.lanes() / op->dtype.lanes(); - Array extents = op->extents; - extents.Set(extents.size() - 1, - extents[extents.size() - 1] / make_const(extents[0].dtype(), factor)); + PrimExpr last_extent = extents[extents.size() - 1]; + extents.Set(extents.size() - 1, last_extent / make_const(last_extent.dtype(), info.factor())); return Allocate(new_buffer_var, info.new_element_dtype, extents, op->condition, op->body); } @@ -1384,7 +1516,7 @@ class VectorTypeRewriter : public StmtExprMutator { * * @param func A pointer to the PrimFunc being modified. */ - void Finalize(PrimFunc* func_ptr) const { + void Finalize(PrimFunc* func_ptr) { ICHECK(func_ptr) << "Finalize expects a non-null pointer"; auto& func = *func_ptr; auto* n = func.CopyOnWrite(); @@ -1410,29 +1542,15 @@ class VectorTypeRewriter : public StmtExprMutator { } n->params = new_params; - // Remap the Buffer objects in so that the buffers use the new buffer variables + // Remap the Buffer objects in PrimFunc::buffer_map so that the + // buffers use the new buffer variables Map new_buffer_map; for (const auto& pair : n->buffer_map) { Var key = pair.first; Buffer old_buffer = pair.second; Var old_var = old_buffer->data; - - auto it = rewrite_map_.find(old_var.get()); - if (it == rewrite_map_.end()) { - new_buffer_map.Set(key, old_buffer); - } else { - auto& info = it->second; - int factor = info.new_element_dtype.lanes() / info.old_element_dtype.lanes(); - ICHECK_EQ(factor * info.new_element_dtype.lanes(), info.old_element_dtype.lanes()); - - auto* buffer_cow = old_buffer.CopyOnWrite(); - buffer_cow->data = info.new_buffer_var; - buffer_cow->dtype = info.new_element_dtype; - size_t ndim = buffer_cow->shape.size(); - const auto& last_dim = buffer_cow->shape[ndim - 1]; - buffer_cow->shape.Set(ndim - 1, last_dim / make_const(last_dim.dtype(), factor)); - new_buffer_map.Set(key, old_buffer); - } + Buffer new_buffer = RemapBuffer(old_buffer); + new_buffer_map.Set(key, new_buffer); } n->buffer_map = new_buffer_map; } @@ -1443,10 +1561,18 @@ class VectorTypeRewriter : public StmtExprMutator { Var new_buffer_var; DataType old_element_dtype; DataType new_element_dtype; + + int factor() const { + int old_lanes = old_element_dtype.lanes(); + int new_lanes = new_element_dtype.lanes(); + ICHECK_EQ(new_lanes % old_lanes, 0); + return new_lanes / old_lanes; + } }; bool rewrite_indices_{true}; std::unordered_map rewrite_map_; + std::unordered_map buffer_map_; }; // Rewrite allocates, pointer parameters, and buffer map into vectorized versions @@ -1454,12 +1580,14 @@ class VectorTypeRewriter : public StmtExprMutator { PrimFunc PointerValueTypeRewrite(PrimFunc f, bool allow_untyped_pointers = false, bool rewrite_params = true, bool rewrite_buffer_map = true, bool rewrite_allocate_node = true, bool rewrite_indices = true, - bool rewrite_let_node = true) { + bool rewrite_let_node = true, + bool rewrite_allocate_const_node = true) { VectorTypeAccessChecker checker(f->params, f->buffer_map, allow_untyped_pointers); checker(f->body); VectorTypeRewriter rewriter(checker.info_map_, rewrite_params, 
rewrite_buffer_map, - rewrite_allocate_node, rewrite_indices, rewrite_let_node); + rewrite_allocate_node, rewrite_indices, rewrite_let_node, + rewrite_allocate_const_node); PrimFuncNode* n = f.CopyOnWrite(); n->body = rewriter(std::move(n->body)); rewriter.Finalize(&f); @@ -1473,7 +1601,13 @@ Pass StorageRewrite() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true); - return PointerValueTypeRewrite(std::move(f), true, false, false, true, false, true); + // Parameters may not be rewritten, but internal allocations may. + // Vectorization of AllocateConst is currently disabled, as it has + // indexing issues for types that include padding (e.g. int8x3 + // padded out to 32 bits) would require either rewriting + // AllocateConst::data, or would require the code generators to + // handle vectorized constants. + return PointerValueTypeRewrite(std::move(f), true, false, false, true, true, true, false); }; return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {}); } diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index 35e4563b8f58..ce3f8fd3e3ac 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -177,22 +177,54 @@ class ThreadSyncPlanner : public StorageAccessVisitor { private: // find conflicting entry in vec. - bool FindConflict(const std::vector& vec, const AccessEntry& e, bool loop_carry) { - for (const AccessEntry& x : vec) { - if (x.buffer.same_as(e.buffer)) { - // Assumes no race between threads - // Same index value means no conflicts - // TODO(tqchen) more standard set based testing. - if (e.touched.IsSinglePoint() && x.touched.IsSinglePoint()) { - if (ExprDeepEqual()(e.touched.PointValue(), x.touched.PointValue())) continue; - } - if (x.double_buffer_write && e.type == kRead && !loop_carry) continue; + bool FindConflict(const std::vector& prev, const AccessEntry& curr, + bool loop_carry) { + for (const AccessEntry& x : prev) { + if (FindConflict(x, curr, loop_carry)) { return true; } } return false; } + bool FindConflict(const AccessEntry& prev, const AccessEntry& curr, bool loop_carry) { + // Access to different buffers does not conflict. + if (!prev.buffer.same_as(curr.buffer)) { + return false; + } + + // Assumes no race between threads + // Same index value means no conflicts + // TODO(tqchen) more standard set based testing. + bool has_same_index = true; + for (size_t i = 0; i < prev.touched.size(); i++) { + const auto& prev_intset = prev.touched[i]; + const auto& curr_intset = curr.touched[i]; + + bool provably_same_index = + prev_intset.IsSinglePoint() && curr_intset.IsSinglePoint() && + ExprDeepEqual()(prev_intset.PointValue(), curr_intset.PointValue()); + + if (!provably_same_index) { + has_same_index = false; + break; + } + } + if (has_same_index) { + return false; + } + + // If this is a read into a double buffer that was previously + // swapped out, then it doesn't conflict. + if (prev.double_buffer_write && curr.type == kRead && !loop_carry) { + return false; + } + + // If nothing else allows sharing the same buffer, then they are + // in conflict. + return true; + } + private: // synchronization scope StorageScope sync_scope_; @@ -222,16 +254,25 @@ class ThreadSyncInserter : public StmtExprMutator { } } PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. 
Please use BufferLoadNode instead."; + return PrimExpr(); + } + + Stmt VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + PrimExpr VisitExpr_(const BufferLoadNode* op) final { if (sync_scope_.rank == StorageRank::kGlobal && - GetScope(op->buffer_var).rank == StorageRank::kGlobal) { - ++rw_stats_[op->buffer_var].read_count; + GetScope(op->buffer->data).rank == StorageRank::kGlobal) { + ++rw_stats_[op->buffer->data].read_count; } return StmtExprMutator::VisitExpr_(op); } - Stmt VisitStmt_(const StoreNode* op) final { + Stmt VisitStmt_(const BufferStoreNode* op) final { if (sync_scope_.rank == StorageRank::kGlobal && - GetScope(op->buffer_var).rank == StorageRank::kGlobal) { - ++rw_stats_[op->buffer_var].write_count; + GetScope(op->buffer->data).rank == StorageRank::kGlobal) { + ++rw_stats_[op->buffer->data].write_count; } return StmtExprMutator::VisitStmt_(op); } diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index c6e0b5c5f41e..e1d0688ab537 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -134,6 +134,11 @@ class LoopUnroller : public StmtExprMutator { } Stmt VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { ++step_count_; return StmtExprMutator::VisitStmt_(op); } diff --git a/src/tir/transforms/update_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc index 4143577a0b17..69db85eda2df 100644 --- a/src/tir/transforms/update_pointer_storage_scope.cc +++ b/src/tir/transforms/update_pointer_storage_scope.cc @@ -58,22 +58,60 @@ PrimExpr UpdatePointerStorageScope::VisitExpr_(const VarNode* op) { return it->second; } -PrimExpr UpdatePointerStorageScope::VisitExpr_(const LoadNode* op) { - auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); - return Load(op->dtype, Downcast(remapped), StmtExprMutator::VisitExpr(op->index), - StmtExprMutator::VisitExpr(op->predicate)); -} - Stmt UpdatePointerStorageScope::VisitStmt_(const AllocateNode* op) { auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition), StmtExprMutator::VisitStmt(op->body)); } +template +Node UpdatePointerStorageScope::UpdateBufferAccess(Node node) { + auto new_buffer = GetUpdatedBuffer(node->buffer); + if (!new_buffer.same_as(node->buffer)) { + auto writer = node.CopyOnWrite(); + writer->buffer = new_buffer; + } + return node; +} + +Buffer UpdatePointerStorageScope::GetUpdatedBuffer(Buffer buf) { + // Use the cached buffer, if it exists. + auto key = buf.get(); + auto it = new_buffer_remap_.find(key); + if (it != new_buffer_remap_.end()) { + return it->second; + } + + // Update the buffer's var, if needed. + auto remapped = Downcast(StmtExprMutator::VisitExpr(buf->data)); + if (!remapped.same_as(buf->data)) { + auto writer = buf.CopyOnWrite(); + writer->data = remapped; + } + + // Update the cache and return + new_buffer_remap_[key] = buf; + return buf; +} + +PrimExpr UpdatePointerStorageScope::VisitExpr_(const LoadNode* op) { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. 
Please use BufferLoadNode instead."; + return PrimExpr(); +} + +PrimExpr UpdatePointerStorageScope::VisitExpr_(const BufferLoadNode* op) { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return UpdateBufferAccess(node); +} + Stmt UpdatePointerStorageScope::VisitStmt_(const StoreNode* op) { - auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); - return Store(Downcast(remapped), StmtExprMutator::VisitExpr(op->value), - StmtExprMutator::VisitExpr(op->index), StmtExprMutator::VisitExpr(op->predicate)); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); +} + +Stmt UpdatePointerStorageScope::VisitStmt_(const BufferStoreNode* op) { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return UpdateBufferAccess(node); } } // namespace tir diff --git a/src/tir/transforms/update_pointer_storage_scope.h b/src/tir/transforms/update_pointer_storage_scope.h index f310194a4a51..d5e492e83389 100644 --- a/src/tir/transforms/update_pointer_storage_scope.h +++ b/src/tir/transforms/update_pointer_storage_scope.h @@ -40,11 +40,19 @@ class UpdatePointerStorageScope : public StmtExprMutator { virtual PrimExpr VisitExpr_(const VarNode*); virtual PrimExpr VisitExpr_(const LoadNode*); + virtual PrimExpr VisitExpr_(const BufferLoadNode*); virtual Stmt VisitStmt_(const AllocateNode*); virtual Stmt VisitStmt_(const StoreNode*); + virtual Stmt VisitStmt_(const BufferStoreNode*); private: + template + Node UpdateBufferAccess(Node node); + + Buffer GetUpdatedBuffer(Buffer buf); + std::unordered_map new_var_remap_; + std::unordered_map new_buffer_remap_; }; } // namespace tir diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 0c9c97af650d..feb396569ff9 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -62,34 +62,93 @@ class VecAllocAccess : public StmtExprMutator { public: VecAllocAccess(const VarNode* buf, Var var, int var_lanes) : buf_(buf), var_(var), var_lanes_(var_lanes) {} - // Load + PrimExpr VisitExpr_(const LoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - if (op->buffer_var.get() == buf_) { - return Load(op->dtype, op->buffer_var, op->index * var_lanes_ + var_, op->predicate); - } else { - return expr; - } + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); } - // Store + Stmt VisitStmt_(const StoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - if (op->buffer_var.get() == buf_) { - return Store(op->buffer_var, op->value, op->index * var_lanes_ + var_, op->predicate); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto load = Downcast(StmtExprMutator::VisitExpr_(op)); + return UpdateBufferAccess(load); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto store = Downcast(StmtExprMutator::VisitStmt_(op)); + return UpdateBufferAccess(store); + } + + private: + template + Node UpdateBufferAccess(Node node) { + // Only update the buffer that's being replaced. + if (node->buffer->data.get() != buf_) { + return node; + } + + // Find/make a Buffer object with the correct updated shape. 
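    /* Illustrative example (hypothetical extents): with var_lanes_ == 4,
     * a local allocation and its accesses are rewritten as
     *
     * \code
     *   allocate B[float32 * 16];  B[i] = ...
     *   // becomes
     *   allocate B[float32 * 64];  B[i * 4 + var] = ...
     * \endcode
     *
     * so that each lane of the vectorized loop owns a distinct element.
     */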
diff --git a/src/tir/transforms/vectorize_loop.cc index 0c9c97af650d..feb396569ff9 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -62,34 +62,93 @@ class VecAllocAccess : public StmtExprMutator {
 public:
  VecAllocAccess(const VarNode* buf, Var var, int var_lanes)
      : buf_(buf), var_(var), var_lanes_(var_lanes) {}
-  // Load
+
  PrimExpr VisitExpr_(const LoadNode* op) final {
-    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<LoadNode>();
-    if (op->buffer_var.get() == buf_) {
-      return Load(op->dtype, op->buffer_var, op->index * var_lanes_ + var_, op->predicate);
-    } else {
-      return expr;
-    }
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
+    return PrimExpr();
  }
-  // Store
+
  Stmt VisitStmt_(const StoreNode* op) final {
-    Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<StoreNode>();
-    if (op->buffer_var.get() == buf_) {
-      return Store(op->buffer_var, op->value, op->index * var_lanes_ + var_, op->predicate);
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
+    return Stmt();
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+    return UpdateBufferAccess(load);
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+    return UpdateBufferAccess(store);
+  }
+
+ private:
+  template <typename Node>
+  Node UpdateBufferAccess(Node node) {
+    // Only update the buffer that's being replaced.
+    if (node->buffer->data.get() != buf_) {
+      return node;
+    }
+
+    // Find/make a Buffer object with the correct updated shape.
+    Buffer buf;
+    auto it = buffer_map_.find(node->buffer.get());
+    if (it != buffer_map_.end()) {
+      buf = it->second;
    } else {
-      return stmt;
+      // Extend the least significant dimension by a factor of
+      // var_lanes_. Typically, this will be a 1-d index into a flat
+      // memory space.
+      Array<PrimExpr> shape = node->buffer->shape;
+      shape.Set(shape.size() - 1, analyzer_.Simplify(shape[shape.size() - 1] * var_lanes_));
+
+      // TODO(Lunderberg): Move this pass to be prior to
+      // StorageFlatten/FlattenBuffer, implement by appending a
+      // dimension to the buffer. Since it is currently after the
+      // flattening, the strides are not technically necessary, but
+      // are updated for consistency.
+
+      // Update strides if defined. All strides except the innermost
+      // grow by var_lanes_, matching the extended innermost extent.
+      Array<PrimExpr> strides;
+      for (size_t i = 0; i < node->buffer->strides.size(); i++) {
+        PrimExpr stride = node->buffer->strides[i];
+        if (i != node->buffer->strides.size() - 1) {
+          stride *= var_lanes_;
+        }
+        strides.push_back(analyzer_.Simplify(stride));
+      }
+
+      // Copy everything into the new buffer.
+      buf = node->buffer;
+      auto buf_writer = buf.CopyOnWrite();
+      buf_writer->shape = shape;
+      buf_writer->strides = strides;
+      buffer_map_[buf.get()] = buf;
    }
+
+    // Extend the last index by the number of lanes in the vectorized
+    // variable.
+    Array<PrimExpr> indices = node->indices;
+    indices.Set(indices.size() - 1,
+                analyzer_.Simplify(indices[indices.size() - 1] * var_lanes_ + var_));
+
+    auto writer = node.CopyOnWrite();
+    writer->buffer = buf;
+    writer->indices = indices;
+    return node;
  }

- private:
  // buffer var
  const VarNode* buf_;
+  // Updated buffer objects.
+  std::unordered_map<const BufferNode*, Buffer> buffer_map_;
  // variable to be replaced
  Var var_;
  // the lanes.
  int var_lanes_;
+  // Analyzer for simplifications
+  arith::Analyzer analyzer_;
};
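A worked example of the rewrite implemented above, assuming var_lanes_ == 4 and a vectorized loop variable v (names illustrative): an allocation of flat shape [n] grows to [n * 4], and an access A[i] becomes A[i * 4 + v], keeping the vector lanes in the least significant, fastest-varying dimension:

  tvm::tir::Var i("i", tvm::DataType::Int(32));
  tvm::tir::Var v("v", tvm::DataType::Int(32));
  const int var_lanes = 4;                      // hypothetical lane count
  tvm::PrimExpr new_index = i * var_lanes + v;  // A[i] -> A[i*4 + v]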
// We use ExprFunctor directly instead of StmtExprMutator
@@ -312,15 +371,24 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&)>
  // Load
  PrimExpr VisitExpr_(const LoadNode* op) final {
-    PrimExpr index = this->VisitExpr(op->index);
-    PrimExpr pred = this->VisitExpr(op->predicate);
-    if (index.same_as(op->index) && pred.same_as(op->predicate)) {
-      return GetRef<PrimExpr>(op);
-    } else {
-      int lanes = std::max(index.dtype().lanes(), pred.dtype().lanes());
-      return Load(op->dtype.with_lanes(lanes), op->buffer_var, BroadcastTo(index, lanes),
-                  BroadcastTo(pred, lanes));
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
+    return PrimExpr();
+  }
+  // BufferLoad
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto load = GetRef<BufferLoad>(op);
+
+    auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); };
+    Array<PrimExpr> indices = op->indices;
+    indices.MutateByApply(fmutate);
+
+    if (!indices.same_as(op->indices)) {
+      auto writer = load.CopyOnWrite();
+      writer->indices = indices;
+      writer->LegalizeDType();
    }
+
+    return std::move(load);
  }
  // Let
  PrimExpr VisitExpr_(const LetNode* op) final {
@@ -352,17 +420,50 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&)>
+  // BufferStore
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto store = GetRef<BufferStore>(op);
+
+    auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); };
+    Array<PrimExpr> indices = op->indices;
+    indices.MutateByApply(fmutate);
+
+    PrimExpr value = this->VisitExpr(op->value);
-    PrimExpr index = this->VisitExpr(op->index);
-    PrimExpr pred = this->VisitExpr(op->predicate);
-    if (value.same_as(op->value) && index.same_as(op->index)) {
-      return GetRef<Stmt>(op);
-    } else {
-      int lanes = std::max(value.dtype().lanes(), index.dtype().lanes());
-      lanes = std::max(lanes, pred.dtype().lanes());
-      return Store(op->buffer_var, BroadcastTo(value, lanes), BroadcastTo(index, lanes),
-                   BroadcastTo(pred, lanes));
+
+    if (!indices.same_as(op->indices) || !value.same_as(op->value)) {
+      // How many lanes of indexing are present in the index and
+      // buffer element type, excluding the last index.
+      int other_index_lanes = op->buffer->dtype.lanes();
+      for (size_t i = 0; i < indices.size() - 1; i++) {
+        other_index_lanes *= indices[i].dtype().lanes();
+      }
+
+      // The total number of lanes of indexing, including the last index.
+      int index_lanes = other_index_lanes * indices[indices.size() - 1].dtype().lanes();
+
+      // The total number of lanes in this store operation. Either
+      // the index or the value will be broadcast out to this number
+      // of lanes, depending on which has more lanes.
+      int total_lanes = std::max(index_lanes, value.dtype().lanes());
+
+      ICHECK_EQ(total_lanes % other_index_lanes, 0)
+          << "When storing to buffer " << op->buffer->name << ", cannot produce " << total_lanes
+          << " lanes of storage location by changing the last index.";
+      int last_index_lanes = total_lanes / other_index_lanes;
+
+      // Broadcast the last index such that the total number of index
+      // lanes matches the desired number.
+      indices.Set(indices.size() - 1, BroadcastTo(indices[indices.size() - 1], last_index_lanes));
+
+      auto writer = store.CopyOnWrite();
+      writer->indices = indices;
+      writer->value = BroadcastTo(value, total_lanes);
    }
+
+    return std::move(store);
  }
  // For
  Stmt VisitStmt_(const ForNode* op) final {
@@ -429,23 +530,35 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&)>
    PrimExpr condition = this->VisitExpr(op->condition);
    if (condition.dtype().is_vector()) {
-      LOG(WARNING) << "Cannot handle vector extent in alloc ";
+      LOG(WARNING) << "Cannot handle vector extent in alloc of " << op->buffer_var->name_hint;
      return Scalarize(GetRef<Stmt>(op));
    }
+
+    // Mutate the extents
    Array<PrimExpr> extents;
-    for (size_t i = 0; i < op->extents.size(); i++) {
-      PrimExpr new_ext = this->VisitExpr(op->extents[i]);
+    for (const auto& extent : op->extents) {
+      PrimExpr new_ext = this->VisitExpr(extent);
      if (new_ext.dtype().is_vector()) {
-        LOG(WARNING) << "Cannot handle vector extent in alloc ";
+        LOG(WARNING) << "Cannot handle vector extent in alloc of " << op->buffer_var->name_hint;
        return Scalarize(GetRef<Stmt>(op));
      }
      extents.push_back(new_ext);
    }
-    // place the vector lanes in least significant dimension.
-    extents.push_back(var_lanes_);
-    // rewrite access to buffer internally.
+
+    // TODO(Lunderberg): Move this pass to be prior to
+    // StorageFlatten/FlattenBuffer. That will allow this pass to be
+    // implemented as adding a new buffer dimension, which is later
+    // flattened.
+
+    // Extend the least significant dimension by a factor of
+    // var_lanes_. Typically, this will be a 1-d index into a flat
+    // memory space.
+    extents.Set(extents.size() - 1, extents[extents.size() - 1] * var_lanes_);
+
+    // Rewrite access to the buffer in the body.
    Stmt body = VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body);
    body = this->VisitStmt(body);
    return Allocate(op->buffer_var, op->dtype, extents, condition, body);
diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc index bb959647c7f0..6f4642ff1535 100644
--- a/src/tir/usmp/analysis/extract_buffer_info.cc
+++ b/src/tir/usmp/analysis/extract_buffer_info.cc
@@ -75,8 +75,8 @@ class BufferInfoExtractor : public StmtExprVisitor {
  void VisitStmt_(const AllocateNode* op) override;
  void VisitExpr_(const CallNode* op) override;
  void VisitExpr_(const VarNode* op) override;
-  void VisitExpr_(const LoadNode* op) override;
-  void VisitStmt_(const StoreNode* op) override;
+  void VisitExpr_(const BufferLoadNode* op) override;
+  void VisitStmt_(const BufferStoreNode* op) override;
  void VisitStmt_(const ForNode* op) override;

  void UpdateAliases(const Array<PrimExpr>& args, const PrimFunc& func);
@@ -310,13 +310,13 @@ void BufferInfoExtractor::VisitStmt_(const ForNode* op) {
  scope_stack_.pop();
 }

-void BufferInfoExtractor::VisitExpr_(const LoadNode* op) {
-  this->VisitExpr(op->buffer_var);
+void BufferInfoExtractor::VisitExpr_(const BufferLoadNode* op) {
+  this->VisitExpr(op->buffer->data);
  StmtExprVisitor::VisitExpr_(op);
 }

-void BufferInfoExtractor::VisitStmt_(const StoreNode* op) {
-  this->VisitExpr(op->buffer_var);
+void BufferInfoExtractor::VisitStmt_(const BufferStoreNode* op) {
+  this->VisitExpr(op->buffer->data);
  StmtExprVisitor::VisitStmt_(op);
 }
diff --git a/src/tir/usmp/transform/assign_pool_info.cc index a2304f3b9e3d..930299e4f039 100644
--- a/src/tir/usmp/transform/assign_pool_info.cc
+++ b/src/tir/usmp/transform/assign_pool_info.cc
@@ -110,8 +110,8 @@ IRModule PoolInfoAssigner::operator()() {
    if (kv.second->IsInstance<PrimFuncNode>()) {
      func_ = Downcast<PrimFunc>(kv.second);
      Stmt body = this->VisitStmt(func_->body);
-      PrimFunc new_prim_func =
-          PrimFunc(func_->params, body, func_->ret_type, func_->buffer_map, func_->attrs);
+      PrimFunc new_prim_func = PrimFunc(func_->params, body, func_->ret_type, func_->buffer_map,
+                                        func_->preflattened_buffer_map, func_->attrs);
      mod_->Update(gv, new_prim_func);
    }
  }
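The constructor change above recurs throughout the patch: PrimFunc now takes a preflattened_buffer_map argument between buffer_map and attrs, recording the pre-flattening shapes of the function's I/O buffers. A sketch of the call shape, mirroring the calls in this patch rather than documenting the full signature (`func` and `new_body` are illustrative names):

  // Forward an existing function's map, or pass {} when no
  // pre-flattening information is available.
  tvm::tir::PrimFunc updated(func->params, new_body, func->ret_type, func->buffer_map,
                             /*preflattened_buffer_map=*/{}, func->attrs);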
diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index 6abc48c31be0..b73534090ab5 100644
--- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
+++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
@@ -89,8 +89,8 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
 private:
  PrimExpr VisitExpr_(const CallNode* op) override;
  Stmt VisitStmt_(const AllocateNode* op) override;
-  PrimExpr VisitExpr_(const LoadNode* op) override;
-  Stmt VisitStmt_(const StoreNode* op) override;
+  PrimExpr VisitExpr_(const BufferLoadNode* op) override;
+  Stmt VisitStmt_(const BufferStoreNode* op) override;
  /*! \brief This is a structure where the modified function
   * signature is kept while the body of the function is mutated
@@ -130,6 +130,10 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
  /*! \brief Obtain a resource handle if it's there */
  Optional<Var> GetResourceHandle(const PrimFunc& func);
+  /*! \brief Get the Buffer object representing the mapped access into
+   * the pool.
+   */
+  Buffer GetRemappedBuffer(Buffer buf);

  /*! \brief The tir::Var map to PoolInfo objects */
  Map<tir::Var, PoolInfo> primfunc_args_to_pool_info_map_;
@@ -146,7 +150,15 @@
  /*! \brief After mutation, each allocate buffer is replaced with a tir::Var that is let bound
   * to a position from a pool as designated by a PoolAllocation */
-  Map<tir::Var, tir::Var> allocate_buf_to_let_var_;
+  Map<tir::Var, tir::Var> allocate_var_to_let_var_;
+  /*! \brief A map from the original buffer object to the buffer object
+   * used in the let binding.
+   *
+   * Each key-value pair in this map satisfies
+   * ``allocate_var_to_let_var_[key->data] = value->data``. However,
+   * since more than one `tir::Buffer` may use the same Var, they must
+   * be tracked separately.
+   */
+  Map<tir::Buffer, tir::Buffer> original_buf_to_let_buf_;
  /*! \brief A counter to give references to pools a reproducible unique set of names */
  int pool_var_count_ = 0;
  /*! \brief This toggles to remove non-tvmscript-printable items from the IRModule for unit tests */
@@ -180,12 +192,7 @@ PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::Upda
  String var_name = pool_ref_name + "_var";
  DataType elem_dtype = DataType::UInt(8);
  Var buffer_var(var_name, PointerType(PrimType(elem_dtype), "global"));
-  Var pool_var;
-  if (!emit_tvmscript_printable_) {
-    pool_var = Var(var_name, PointerType(PrimType(elem_dtype), "global"));
-  } else {
-    pool_var = Var(var_name, DataType::Handle(8));
-  }
+  Var pool_var = Var(var_name, PointerType(PrimType(elem_dtype), "global"));
  si.params.push_back(pool_var);
  si.pools_to_params.Set(pool_info, pool_var);
  si.allocated_pool_params.push_back(AllocatedPoolInfo(
@@ -216,8 +223,8 @@ PrimFunc PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams(
  if (emit_tvmscript_printable_) {
    original_attrs = DictAttrs();
  }
-  PrimFunc ret =
-      PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, original_attrs);
+  PrimFunc ret = PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, {},
+                          original_attrs);
  if (!emit_tvmscript_printable_) {
    ret = WithAttr(ret, tvm::attr::kPoolArgs, si.allocated_pool_params);
  }
@@ -255,8 +262,8 @@ Array<PrimExpr> PoolAllocationToOffsetConverter::ReplaceAllocateArgsWithLetArgs(
  Array<PrimExpr> ret;
  for (const PrimExpr& arg : args) {
    if (arg->IsInstance<VarNode>() &&
-        allocate_buf_to_let_var_.find(Downcast<Var>(arg)) != allocate_buf_to_let_var_.end()) {
-      ret.push_back(allocate_buf_to_let_var_[Downcast<Var>(arg)]);
+        allocate_var_to_let_var_.find(Downcast<Var>(arg)) != allocate_var_to_let_var_.end()) {
+      ret.push_back(allocate_var_to_let_var_[Downcast<Var>(arg)]);
    } else {
      ret.push_back(VisitExpr(arg));
    }
@@ -297,37 +304,65 @@ Stmt PoolAllocationToOffsetConverter::VisitStmt_(const AllocateNode* op) {
    PoolAllocation pool_allocation = pool_allocations_[GetRef<Allocate>(op)];
    Var param = scope_info.pools_to_params[pool_allocation->pool_info];
    Buffer buffer_var = scope_info.buffer_map[param];
-    Load load_node =
-        Load(DataType::UInt(8), buffer_var->data, pool_allocation->byte_offset, op->condition);
-    Call address_of_load = Call(DataType::Handle(8), builtin::address_of(), {load_node});
-    Var tir_var;
-    if (!emit_tvmscript_printable_) {
-      tir_var = Var(op->buffer_var->name_hint + "_let", op->buffer_var->type_annotation);
-    } else {
-      tir_var = Var(op->buffer_var->name_hint + "_let", DataType::Handle(8));
+    BufferLoad load_node = BufferLoad(buffer_var, {pool_allocation->byte_offset});
+    Call address_of_load = Call(DataType::Handle(), builtin::address_of(), {load_node});
+
+    Type let_var_type = op->buffer_var->type_annotation;
+    if (emit_tvmscript_printable_) {
+      // Strip the storage_scope from the variable type, as TVMScript
+      // doesn't parse the scoped pointers (e.g. ``T.Ptr[global T.int32]``)
+      // correctly.
+      let_var_type = PointerType(Downcast<PointerType>(let_var_type)->element_type);
    }
-    allocate_buf_to_let_var_.Set(op->buffer_var, tir_var);
+    Var let_var(op->buffer_var->name_hint + "_let", let_var_type);
+    allocate_var_to_let_var_.Set(op->buffer_var, let_var);
    Stmt new_body = VisitStmt(op->body);
-    allocate_buf_to_let_var_.erase(op->buffer_var);
-    return LetStmt(tir_var, address_of_load, new_body);
+    allocate_var_to_let_var_.erase(op->buffer_var);
+    return LetStmt(let_var, address_of_load, new_body);
  }
  return StmtExprMutator::VisitStmt_(op);
 }

-Stmt PoolAllocationToOffsetConverter::VisitStmt_(const StoreNode* op) {
-  if (allocate_buf_to_let_var_.find(op->buffer_var) != allocate_buf_to_let_var_.end()) {
-    return Store(allocate_buf_to_let_var_[op->buffer_var], VisitExpr(op->value), op->index,
-                 VisitExpr(op->predicate));
+Stmt PoolAllocationToOffsetConverter::VisitStmt_(const BufferStoreNode* op) {
+  BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+
+  Buffer remapped = GetRemappedBuffer(store->buffer);
+  if (!op->buffer.same_as(remapped)) {
+    store.CopyOnWrite()->buffer = remapped;
  }
-  return StmtExprMutator::VisitStmt_(op);
+  return std::move(store);
 }

-PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const LoadNode* op) {
-  if (allocate_buf_to_let_var_.find(op->buffer_var) != allocate_buf_to_let_var_.end()) {
-    return Load(op->dtype, allocate_buf_to_let_var_[op->buffer_var], op->index,
-                VisitExpr(op->predicate));
+PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const BufferLoadNode* op) {
+  BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+
+  Buffer remapped = GetRemappedBuffer(load->buffer);
+  if (!op->buffer.same_as(remapped)) {
+    load.CopyOnWrite()->buffer = remapped;
  }
-  return StmtExprMutator::VisitExpr_(op);
+  return std::move(load);
+}
+
+Buffer PoolAllocationToOffsetConverter::GetRemappedBuffer(Buffer original) {
+  {
+    auto it = original_buf_to_let_buf_.find(original);
+    if (it != original_buf_to_let_buf_.end()) {
+      return (*it).second;
+    }
+  }
+
+  Buffer remapped = original;
+
+  auto it = allocate_var_to_let_var_.find(original->data);
+  if (it != allocate_var_to_let_var_.end()) {
+    remapped = Buffer((*it).second, original->dtype, original->shape, original->strides,
+                      original->elem_offset, original->name, original->data_alignment,
+                      original->offset_factor, original->buffer_type, original->axis_separators,
+                      original->span);
+  }
+
+  original_buf_to_let_buf_.Set(original, remapped);
+  return remapped;
 }

 IRModule PoolAllocationToOffsetConverter::operator()() {
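GetRemappedBuffer above mirrors GetUpdatedBuffer from update_pointer_storage_scope.cc, but remaps to the let-bound variable produced in VisitStmt_(AllocateNode), where each pool allocation is lowered to the address of an element of the pool's backing buffer. A sketch of that lowering with illustrative values (a hypothetical pool allocation at byte offset 64):

  tvm::tir::Buffer pool_buf = tvm::tir::decl_buffer({1024}, tvm::DataType::UInt(8), "pool");
  tvm::tir::BufferLoad elem(pool_buf, {64});
  tvm::PrimExpr ptr = tvm::tir::Call(tvm::DataType::Handle(),
                                     tvm::tir::builtin::address_of(), {elem});
  // `ptr` is then bound with a LetStmt in place of the original allocation.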
@@ -340,12 +375,12 @@ IRModule PoolAllocationToOffsetConverter::operator()() {
  // We don't need attrs of PrimFunc that might include non-printable attrs such as target
  // for unit tests where emit_tvmscript_printable_ is to be used.
  if (!emit_tvmscript_printable_) {
-    main_func =
-        PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, main_func->attrs);
+    main_func = PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, {},
+                         main_func->attrs);
    main_func = WithAttr(main_func, tvm::attr::kPoolArgs, si.allocated_pool_params);
  } else {
    main_func =
-        PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, DictAttrs());
+        PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, {}, DictAttrs());
  }
  module_->Update(gv, main_func);
  if (!emit_tvmscript_printable_) {
diff --git a/tests/cpp/tir_analysis_side_effect.cc index a59e4a7f8c05..bd7d7805e7aa 100644
--- a/tests/cpp/tir_analysis_side_effect.cc
+++ b/tests/cpp/tir_analysis_side_effect.cc
@@ -25,10 +25,9 @@
 TEST(SimplePasses, SideEffect) {
  using namespace tvm;
-  auto A = tir::Var("A", DataType::Handle());
+  auto buf = tir::decl_buffer({16}, DataType::Float(32));
  auto i = tir::Var("i", DataType::Int(32));

-  ICHECK(tir::SideEffect(tir::Load(DataType::Float(32), A, i, tir::const_true(1))) ==
-         tir::CallEffectKind::kReadState);
+  ICHECK(tir::SideEffect(tir::BufferLoad(buf, {i})) == tir::CallEffectKind::kReadState);
  ICHECK(tir::SideEffect(exp(tir::Cast(DataType::Float(32), i + 1))) == tir::CallEffectKind::kPure);
  ICHECK(tir::SideEffect(tir::Call(DataType::Handle(), tir::builtin::tvm_storage_sync(), {})) ==
         tir::CallEffectKind::kUpdateState);
diff --git a/tests/python/contrib/test_ethosu/infra.py index f46792c1e6e5..ecce814c259c 100644
--- a/tests/python/contrib/test_ethosu/infra.py
+++ b/tests/python/contrib/test_ethosu/infra.py
@@ -276,8 +276,8 @@ def get_convolutional_args(call, include_buffers=False, remove_constants=False):
            continue
        elif isinstance(arg, tvm.tir.expr.IntImm) or isinstance(arg, tvm.tir.expr.FloatImm):
            conv_args.append(arg.value)
-        elif isinstance(arg, tvm.tir.expr.Load) and not include_buffers:
-            conv_args.append(arg.index)
+        elif isinstance(arg, tvm.tir.expr.BufferLoad) and not include_buffers:
+            conv_args.append(arg.indices[0])
        else:
            conv_args.append(arg)

@@ -428,8 +428,8 @@ def get_pooling_args(call, include_buffers=False):
    for i, arg in enumerate(args):
        if isinstance(arg, tvm.tir.expr.IntImm) or isinstance(arg, tvm.tir.expr.FloatImm):
            pooling_args.append(arg.value)
-        elif isinstance(arg, tvm.tir.expr.Load) and not include_buffers:
-            pooling_args.append(arg.index)
+        elif isinstance(arg, tvm.tir.expr.BufferLoad) and not include_buffers:
+            pooling_args.append(arg.indices[0])
        else:
            pooling_args.append(arg)

@@ -479,8 +479,8 @@ def get_binary_elementwise_args(call, include_buffers=False):
    for i, arg in enumerate(args):
        if isinstance(arg, tvm.tir.expr.IntImm) or isinstance(arg, tvm.tir.expr.FloatImm):
            binary_elementwise_args.append(arg.value)
-        elif isinstance(arg, tvm.tir.expr.Load) and not include_buffers:
-            binary_elementwise_args.append(arg.index)
+        elif isinstance(arg, tvm.tir.expr.BufferLoad) and not include_buffers:
+            binary_elementwise_args.append(arg.indices[0])
        else:
            binary_elementwise_args.append(arg)
diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py index 315712996ac8..8878e467aad7 100644
--- a/tests/python/contrib/test_ethosu/test_encode_constants.py
+++ b/tests/python/contrib/test_ethosu/test_encode_constants.py
@@ -34,32 +34,36 @@
 @tvm.script.ir_module
 class WeightStreamOnly:
    @T.prim_func
-    def main(placeholder: 
T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: + def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("uint8", "") - buffer_3 = T.buffer_var("uint8", "") - buffer_4 = T.buffer_var("uint8", "") - buffer_5 = T.buffer_var("uint8", "") - buffer_6 = T.buffer_var("uint8", "") - buffer_7 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([128], "uint8") + buffer_1 = T.buffer_decl([32], "uint8") + buffer_2 = T.buffer_decl([112], "uint8") + buffer_3 = T.buffer_decl([32], "uint8") + buffer_4 = T.buffer_decl([112], "uint8") + buffer_5 = T.buffer_decl([32], "uint8") + buffer_6 = T.buffer_decl([112], "uint8") + buffer_7 = T.buffer_decl([32], "uint8") + T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data) + T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) # body - placeholder_global = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 128, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 128, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", 
buffer_6, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_7, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + p1_global = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) + p1_global_1 = T.buffer_decl([112], dtype="uint8", data=p1_global.data) + p2_global_1 = T.buffer_decl([32], dtype="uint8", data=p2_global.data) + T.evaluate(T.call_extern("ethosu_copy", buffer[0], 128, p1_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 32, p2_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1_global[0], 128, 12, p2_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 112, p1_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 32, p2_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1_global_1[0], 112, 12, p2_global_1[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_4[0], 112, p1_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_5[0], 32, p2_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1_global_1[0], 112, 12, p2_global_1[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_6[0], 112, p1_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_7[0], 32, p2_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1_global_1[0], 112, 12, p2_global_1[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -107,20 +111,22 @@ def _get_func(): @tvm.script.ir_module class RereadWeights: @T.prim_func - def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: + def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": 
True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([304], "uint8") + buffer_1 = T.buffer_decl([80], "uint8") + T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data) + T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) # body placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin":True}) placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write.data, 64), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer[0], 304, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 80, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, 12, placeholder_d_global[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer[0], 304, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 80, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, 12, placeholder_d_global[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -168,17 +174,19 @@ def _get_func(): @tvm.script.ir_module class DirectReadOnly: @T.prim_func - def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: + def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: # function attr dict 
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("uint8", "") - buffer_3 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([592], "uint8") + buffer_1 = T.buffer_decl([160], "uint8") + buffer_2 = T.buffer_decl([160], "uint8") + buffer_3 = T.buffer_decl([80], "uint8") + T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data) + T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) # body ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 592, 12, T.load("uint8", buffer_1, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 160, 12, T.load("uint8", buffer_3, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, 12, buffer_1[0], 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 160, 12, buffer_3[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -225,36 +233,38 @@ def _get_func(): @tvm.script.ir_module class MixedRead: @T.prim_func - def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: + def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("uint8", "") - buffer_3 = T.buffer_var("uint8", "") - buffer_4 = T.buffer_var("uint8", "") - buffer_5 = T.buffer_var("uint8", "") - buffer_6 = T.buffer_var("uint8", "") - buffer_7 = T.buffer_var("uint8", "") - buffer_8 = T.buffer_var("uint8", "") - buffer_9 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([592], "uint8") + buffer_1 = T.buffer_decl([160], "uint8") + buffer_2 = T.buffer_decl([80], "uint8") + buffer_3 = T.buffer_decl([32], "uint8") + buffer_4 = T.buffer_decl([80], "uint8") + buffer_5 = T.buffer_decl([32], "uint8") + buffer_6 = T.buffer_decl([80], "uint8") + buffer_7 = T.buffer_decl([32], "uint8") + buffer_8 = T.buffer_decl([80], "uint8") + buffer_9 = T.buffer_decl([32], "uint8") + T.preflattened_buffer(placeholder, [1, 16, 
16, 32], "int8", data=placeholder.data) + T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) # body ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) placeholder_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 592, 12, T.load("uint8", buffer_1, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_7, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_8, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_9, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 
1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, 12, buffer_1[0], 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 80, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_4[0], 80, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_5[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_6[0], 80, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_7[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_8[0], 80, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_9[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py index f6e0e2d855cd..f82351c28c05 100644 --- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py +++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py @@ -30,23 +30,26 @@ @tvm.script.ir_module class ReferenceModule: @T.prim_func - def main(placeholder: T.Buffer[(1, 8, 12, 16), "int8"], placeholder_1: T.Buffer[(1, 8, 10, 16), "int8"], T_concat: T.Buffer[(1, 8, 32, 16), "int8"]) -> None: + def main(placeholder: T.Buffer[(1536,), "int8"], placeholder_1: T.Buffer[(1280,), "int8"], T_concat: T.Buffer[(4096,), "int8"]) -> None: # function attr dict 
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("uint8", "") - buffer_3 = T.buffer_var("uint8", "") - buffer_4 = T.buffer_var("uint8", "") - buffer_5 = T.buffer_var("uint8", "") - buffer_6 = T.buffer_var("uint8", "") - buffer_7 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([2992], "uint8") + buffer_1 = T.buffer_decl([160], "uint8") + buffer_2 = T.buffer_decl([2992], "uint8") + buffer_3 = T.buffer_decl([160], "uint8") + buffer_4 = T.buffer_decl([2992], "uint8") + buffer_5 = T.buffer_decl([160], "uint8") + buffer_6 = T.buffer_decl([2992], "uint8") + buffer_7 = T.buffer_decl([160], "uint8") + T.preflattened_buffer(placeholder, [1, 8, 12, 16], "int8", data=placeholder.data) + T.preflattened_buffer(placeholder_1, [1, 8, 10, 16], "int8", data=placeholder_1.data) + T.preflattened_buffer(T_concat, [1, 8, 32, 16], "int8", data=T_concat.data) # body T_concat_1 = T.allocate([2816], "int8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, T.load("int8", placeholder_1.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 160, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T.load("int8", T_concat_1, 192), 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 2992, 12, T.load("uint8", buffer_1, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, T.load("int8", T_concat_1, 192), 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T.load("int8", T_concat.data, 352), 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 2992, 12, T.load("uint8", buffer_3, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 12, 16, 8, 0, 12, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 192, 16, 1, "int8", 8, 12, 16, 8, 0, 12, T.load("int8", T_concat_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_4, 0), 2992, 12, T.load("uint8", buffer_5, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 22, 16, 8, 0, 22, T.load("int8", T_concat_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 22, 16, 8, 0, 22, T.load("int8", T_concat.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_6, 0), 2992, 12, T.load("uint8", buffer_7, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, placeholder_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 160, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 2992, 12, buffer_1[0], 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat[352], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 2992, 12, buffer_3[0], 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 12, 16, 8, 0, 12, 
placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 192, 16, 1, "int8", 8, 12, 16, 8, 0, 12, T_concat_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer_4[0], 2992, 12, buffer_5[0], 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 22, 16, 8, 0, 22, T_concat_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 22, 16, 8, 0, 22, T_concat[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, buffer_6[0], 2992, 12, buffer_7[0], 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index 67fb2c760962..5a9aa9855183 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -333,114 +333,126 @@ def _visit(stmt): @tvm.script.ir_module class Conv2dDoubleCascade1: @T.prim_func - def main(placeholder_5: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_write_1: T.Buffer[(1, 8, 8, 8), "int8"]) -> None: + def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("uint8", "") - buffer_3 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([304], "uint8") + buffer_1 = T.buffer_decl([80], "uint8") + buffer_2 = T.buffer_decl([320], "uint8") + buffer_3 = T.buffer_decl([160], "uint8") + T.preflattened_buffer(placeholder_5, [1, 8, 8, 3], 'int8', data=placeholder_5.data) + T.preflattened_buffer(ethosu_write_1, [1, 8, 8, 8], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([1024], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 160, 12, T.load("uint8", buffer_2, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 304, 12, T.load("uint8", buffer_1, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 12), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 160, 12, T.load("uint8", buffer_2, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 32), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 
304, 12, T.load("uint8", buffer_1, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, buffer_3[0], 160, 12, buffer_2[0], 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, buffer[0], 304, 12, buffer_1[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, placeholder_5[12], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, buffer_3[0], 160, 12, buffer_2[0], 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[32], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, buffer[0], 304, 12, buffer_1[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @tvm.script.ir_module class Conv2dDoubleCascade2: @T.prim_func - def main(placeholder_5: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_write_1: T.Buffer[(1, 8, 8, 8), "int8"]) -> None: + def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("uint8", "") - buffer_3 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([80], "uint8") + buffer_1 = T.buffer_decl([320], "uint8") + buffer_2 = T.buffer_decl([1312], "uint8") + buffer_3 = T.buffer_decl([2608], "uint8") + T.preflattened_buffer(placeholder_5, [1, 8, 8, 3], 'int8', data=placeholder_5.data) + T.preflattened_buffer(ethosu_write_1, [1, 8, 8, 8], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([1536], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 1312, 12, T.load("uint8", buffer_1, 0), 320, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 2608, 12, T.load("uint8", buffer, 0), 80, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 48), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", 
ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 1312, 12, T.load("uint8", buffer_1, 0), 320, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, T.load("int8", ethosu_write_1.data, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 2608, 12, T.load("uint8", buffer, 0), 80, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 1312, 12, buffer_1[0], 320, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, buffer_3[0], 2608, 12, buffer[0], 80, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[48], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 1312, 12, buffer_1[0], 320, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, ethosu_write_1[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, buffer_3[0], 2608, 12, buffer[0], 80, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
     __tvm_meta__ = None
 @tvm.script.ir_module
 class Conv2dDoubleCascade3:
     @T.prim_func
-    def main(placeholder_5: T.Buffer[(1, 16, 16, 3), "int8"], ethosu_write_1: T.Buffer[(1, 20, 4, 8), "int8"]) -> None:
+    def main(placeholder_5: T.Buffer[(768,), "int8"], ethosu_write_1: T.Buffer[(640,), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-        buffer = T.buffer_var("uint8", "")
-        buffer_1 = T.buffer_var("uint8", "")
-        buffer_2 = T.buffer_var("uint8", "")
-        buffer_3 = T.buffer_var("uint8", "")
+        buffer = T.buffer_decl([1744], "uint8")
+        buffer_1 = T.buffer_decl([80], "uint8")
+        buffer_2 = T.buffer_decl([320], "uint8")
+        buffer_3 = T.buffer_decl([880], "uint8")
+        T.preflattened_buffer(placeholder_5, [1, 16, 16, 3], 'int8', data=placeholder_5.data)
+        T.preflattened_buffer(ethosu_write_1, [1, 20, 4, 8], 'int8', data=ethosu_write_1.data)
         # body
         ethosu_write_2 = T.allocate([2560], "int8", "global", annotations={"disable_lower_builtin": True})
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3, 0), 880, 12, T.load("uint8", buffer_2, 0), 320, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer, 0), 1744, 12, T.load("uint8", buffer_1, 0), 80, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, T.load("int8", placeholder_5.data, 192), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3, 0), 880, 12, T.load("uint8", buffer_2, 0), 320, 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 10, 8, 32, 10, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer, 0), 1744, 12, T.load("uint8", buffer_1, 0), 80, 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 16, 3, 4, 0, 16, T.load("int8", placeholder_5.data, 576), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 4, 8, 32, 4, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3, 0), 880, 12, T.load("uint8", buffer_2, 0), 320, 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 32, 4, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 4, 8, 4, 0, 4, T.load("int8", ethosu_write_1.data, 512), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer, 0), 1744, 12, T.load("uint8", buffer_1, 0), 80, 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, ethosu_write_2[512], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, 12, buffer_2[0], 320, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, ethosu_write_2[512], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, buffer[0], 1744, 12, buffer_1[0], 80, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, placeholder_5[192], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, 12, buffer_2[0], 320, 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 10, 8, 32, 10, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, buffer[0], 1744, 12, buffer_1[0], 80, 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 16, 3, 4, 0, 16, placeholder_5[576], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 4, 8, 32, 4, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, 12, buffer_2[0], 320, 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 32, 4, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 4, 8, 4, 0, 4, ethosu_write_1[512], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, buffer[0], 1744, 12, buffer_1[0], 80, 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
     __tvm_meta__ = None
 @tvm.script.ir_module
 class Conv2dDoubleCascade4:
     @T.prim_func
-    def main(placeholder_5: T.Buffer[(1, 8, 1, 8, 16), "int8"], ethosu_write_1: T.Buffer[(1, 8, 2, 8, 16), "int8"]) -> None:
+    def main(placeholder_5: T.Buffer[(1024,), "int8"], ethosu_write_1: T.Buffer[(2048,), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-        buffer = T.buffer_var("uint8", "")
-        buffer_1 = T.buffer_var("uint8", "")
-        buffer_2 = T.buffer_var("uint8", "")
-        buffer_3 = T.buffer_var("uint8", "")
+        buffer = T.buffer_decl([1456], "uint8")
+        buffer_1 = T.buffer_decl([352], "uint8")
+        buffer_2 = T.buffer_decl([272], "uint8")
+        buffer_3 = T.buffer_decl([11040], "uint8")
+        T.preflattened_buffer(placeholder_5, [1, 8, 1, 8, 16], 'int8', data=placeholder_5.data)
+        T.preflattened_buffer(ethosu_write_1, [1, 8, 2, 8, 16], 'int8', data=ethosu_write_1.data)
         # body
         ethosu_write_2 = T.allocate([2304], "int8", "global", annotations={"disable_lower_builtin": True})
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 1456, 12, T.load("uint8", buffer_1, 0), 352, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 11040, 12, T.load("uint8", buffer_2, 0), 272, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 1456, 12, T.load("uint8", buffer_1, 0), 352, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, T.load("int8", ethosu_write_1.data, 1024), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 11040, 12, T.load("uint8", buffer_2, 0), 272, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[384], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, buffer[0], 1456, 12, buffer_1[0], 352, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[384], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, buffer_3[0], 11040, 12, buffer_2[0], 272, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[256], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, buffer[0], 1456, 12, buffer_1[0], 352, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, ethosu_write_1[1024], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, buffer_3[0], 11040, 12, buffer_2[0], 272, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
     __tvm_meta__ = None
 @tvm.script.ir_module
 class Conv2dDoubleCascade5:
     @T.prim_func
-    def main(placeholder: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_write: T.Buffer[(1, 32, 32, 8), "int8"]) -> None:
+    def main(placeholder: T.Buffer[(192,), "int8"], ethosu_write: T.Buffer[(8192,), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-        buffer = T.buffer_var("uint8", "")
-        buffer_1 = T.buffer_var("uint8", "")
-        buffer_2 = T.buffer_var("uint8", "")
-        buffer_3 = T.buffer_var("uint8", "")
+        buffer = T.buffer_decl([160], "uint8")
+        buffer_1 = T.buffer_decl([320], "uint8")
+        buffer_2 = T.buffer_decl([304], "uint8")
+        buffer_3 = T.buffer_decl([80], "uint8")
+        T.preflattened_buffer(placeholder, [1, 8, 8, 3], 'int8', data=placeholder.data)
+        T.preflattened_buffer(ethosu_write, [1, 32, 32, 8], 'int8', data=ethosu_write.data)
         # body
         ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True})
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 160, 12, T.load("uint8", buffer_1, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 304, 12, T.load("uint8", buffer_3, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, T.load("int8", placeholder.data, 96), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 160, 12, T.load("uint8", buffer_1, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, T.load("int8", ethosu_write.data, 4096), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 304, 12, T.load("uint8", buffer_3, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, buffer[0], 160, 12, buffer_1[0], 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 304, 12, buffer_3[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, placeholder[96], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, buffer[0], 160, 12, buffer_1[0], 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, ethosu_write[4096], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 304, 12, buffer_3[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle"))
     __tvm_meta__ = None
 @tvm.script.ir_module
 class Conv2dDoubleCascade6:
     @T.prim_func
-    def main(placeholder: T.Buffer[(1, 8, 1, 8, 16), "int8"], ethosu_write: T.Buffer[(1, 32, 2, 32, 16), "int8"]) -> None:
+    def main(placeholder: T.Buffer[(1024,), "int8"], ethosu_write: T.Buffer[(32768,), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-        buffer = T.buffer_var("uint8", "")
-        buffer_1 = T.buffer_var("uint8", "")
-        buffer_2 = T.buffer_var("uint8", "")
-        buffer_3 = T.buffer_var("uint8", "")
+        buffer = T.buffer_decl([1456], "uint8")
+        buffer_1 = T.buffer_decl([352], "uint8")
+        buffer_2 = T.buffer_decl([11040], "uint8")
+        buffer_3 = T.buffer_decl([272], "uint8")
+        T.preflattened_buffer(placeholder, [1, 8, 1, 8, 16], 'int8', data=placeholder.data)
+        T.preflattened_buffer(ethosu_write, [1, 32, 2, 32, 16], 'int8', data=ethosu_write.data)
         # body
         ethosu_write_1 = T.allocate([12288], "int8", "global", annotations={"disable_lower_builtin":True})
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 3, 8, 0, 8, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 16, 16, 35, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 768, 16, 256, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 1456, 12, T.load("uint8", buffer_1, 0), 352, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 35, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 768, 16, 256, "int8", 32, 32, 26, 32, 0, 32, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 1024, 16, 512, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 11040, 12, T.load("uint8", buffer_3, 0), 272, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", dtype="handle"))
"NHCWB16", 1024, 16, 512, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 11040, 12, T.load("uint8", buffer_3, 0), 272, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 3, 8, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 16, 16, 35, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 768, 16, 256, 3, 3, 1, 1, 1, 1, buffer[0], 1456, 12, buffer_1[0], 352, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 35, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 768, 16, 256, "int8", 32, 32, 26, 32, 0, 32, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 1024, 16, 512, 3, 3, 1, 1, 1, 1, buffer_2[0], 11040, 12, buffer_3[0], 272, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -591,26 +603,30 @@ def _get_func( @tvm.script.ir_module class Conv2dInlineCopy1: @T.prim_func - def main(placeholder_3: T.Buffer[(1, 10, 12, 8), "int8"], ethosu_write_1: T.Buffer[(1, 8, 8, 16), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(960,), "int8"], ethosu_write_1: T.Buffer[(1024,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([848], "uint8") + buffer_1 = T.buffer_decl([160], "uint8") + T.preflattened_buffer(placeholder_3, [1, 10, 12, 8], 'int8', data=placeholder_3.data) + T.preflattened_buffer(ethosu_write_1, [1, 8, 8, 16], 'int8', data=ethosu_write_1.data) # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, T.load("int8", placeholder_3.data, 120), 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 848, 12, T.load("uint8", buffer_1, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, placeholder_3[120], 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 848, 12, buffer_1[0], 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @tvm.script.ir_module class Conv2dInlineCopy2: @T.prim_func - def main(placeholder_3: T.Buffer[(1, 7, 9, 5), "int8"], ethosu_write_1: T.Buffer[(1, 3, 5, 16), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(315,), "int8"], ethosu_write_1: T.Buffer[(240,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([160], "uint8") + buffer_1 = T.buffer_decl([656], "uint8") + T.preflattened_buffer(placeholder_3, [1, 7, 9, 5], 'int8', data=placeholder_3.data) + T.preflattened_buffer(ethosu_write_1, [1, 3, 5, 16], 'int8', data=ethosu_write_1.data) # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, T.load("int8", placeholder_3.data, 146), 0, 0, 0, T.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", 
buffer_1, 0), 656, 12, T.load("uint8", buffer, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, placeholder_3[146], 0, 0, 0, T.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 656, 12, buffer[0], 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -646,56 +662,64 @@ def _get_func(ifm_shape, lower, upper, ofm_channels=16): @tvm.script.ir_module class Conv2dInlineReshape1: @T.prim_func - def main(placeholder_3: T.Buffer[(4, 6, 8, 1), "int8"], ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([160], "uint8") + buffer_1 = T.buffer_decl([848], "uint8") + T.preflattened_buffer(placeholder_3, [4, 6, 8, 1], 'int8', data=placeholder_3.data) + T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data) # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @tvm.script.ir_module class Conv2dInlineReshape2: @T.prim_func - def main(placeholder_3: T.Buffer[(1, 24, 8), "int8"], ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([160], "uint8") + buffer_1 = T.buffer_decl([848], "uint8") + T.preflattened_buffer(placeholder_3, [1, 24, 8], 'int8', data=placeholder_3.data) + T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 
         # body
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
     __tvm_meta__ = None
 @tvm.script.ir_module
 class Conv2dInlineReshape3:
     @T.prim_func
-    def main(placeholder_3: T.Buffer[(192, 1), "int8"], ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None:
+    def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768,), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-        buffer = T.buffer_var("uint8", "")
-        buffer_1 = T.buffer_var("uint8", "")
+        buffer = T.buffer_decl([160], "uint8")
+        buffer_1 = T.buffer_decl([848], "uint8")
+        T.preflattened_buffer(placeholder_3, [192, 1], 'int8', data=placeholder_3.data)
+        T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data)
         # body
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
     __tvm_meta__ = None
 @tvm.script.ir_module
 class Conv2dInlineReshape4:
     @T.prim_func
-    def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None:
+    def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768,), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-        buffer = T.buffer_var("uint8", "")
-        buffer_1 = T.buffer_var("uint8", "")
+        buffer = T.buffer_decl([160], "uint8")
+        buffer_1 = T.buffer_decl([848], "uint8")
+        T.preflattened_buffer(placeholder_3, [192], 'int8', data=placeholder_3.data)
+        T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data)
         # body
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
     __tvm_meta__ = None
 # fmt: on
diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py
index 7aee57d548fe..4bfbae5f03b7 100644
--- a/tests/python/contrib/test_ethosu/test_replace_copy.py
+++ b/tests/python/contrib/test_ethosu/test_replace_copy.py
@@ -31,17 +31,19 @@
 @tvm.script.ir_module
 class ReferenceModule:
     @T.prim_func
-    def main(placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write_1: T.Buffer[(1, 16, 16, 8), "int8"]) -> None:
+    def main(placeholder_3: T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(2048,), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-        buffer = T.buffer_var("uint8", "")
-        buffer_1 = T.buffer_var("uint8", "")
+        buffer = T.buffer_decl([80], "uint8")
+        buffer_1 = T.buffer_decl([304], "uint8")
+        T.preflattened_buffer(placeholder_3, [1, 16, 16, 32], dtype="int8", data=placeholder_3.data)
[1, 16, 16, 32], dtype="int8", data=placeholder_3.data) + T.preflattened_buffer(ethosu_write_1, [1, 16, 16, 8], dtype="int8", data=ethosu_write_1.data) # body placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin": True}) placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 304, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer[0], 80, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, 12, placeholder_d_global[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -75,22 +77,26 @@ def _get_func(): @tvm.script.ir_module class WeightStream: @T.prim_func - def main(placeholder_5: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write_1: T.Buffer[(1, 16, 16, 16), "int8"]) -> None: + def main(placeholder_5: T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(4096,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("uint8", "") - buffer_3 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([416], "uint8") + buffer_1 = T.buffer_decl([112], "uint8") + buffer_2 = T.buffer_decl([272], "uint8") + buffer_3 = T.buffer_decl([64], "uint8") + T.preflattened_buffer(placeholder_5, [1, 16, 16, 32], dtype="int8", data=placeholder_5.data) + T.preflattened_buffer(ethosu_write_1, [1, 16, 16, 16], dtype="int8", data=ethosu_write_1.data) # body - placeholder_global = T.allocate([416], "uint8", "global", annotations={"disable_lower_builtin": True}) - placeholder_d_global = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 416, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 112, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 416, 12, T.load("uint8", placeholder_d_global, 0), 112, 0, 0, 0, 0, "NONE", 0, 0, 
"TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2, 0), 272, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3, 0), 64, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, T.load("int8", ethosu_write_1.data, 10), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 272, 12, T.load("uint8", placeholder_d_global, 0), 64, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + placeholder_global_unrolled_iter_0 = T.allocate([416], "uint8", "global", annotations={"disable_lower_builtin": True}) + placeholder_global_unrolled_iter_1 = T.buffer_decl([272], "uint8", data=placeholder_global_unrolled_iter_0.data) + placeholder_d_global_unrolled_iter_0 = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin": True}) + placeholder_d_global_unrolled_iter_1 = T.buffer_decl([64], dtype="uint8", data=placeholder_d_global_unrolled_iter_0.data) + T.evaluate(T.call_extern("ethosu_copy", buffer[0], 416, placeholder_global_unrolled_iter_0[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 112, placeholder_d_global_unrolled_iter_0[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_global_unrolled_iter_0[0], 416, 12, placeholder_d_global_unrolled_iter_0[0], 112, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 272, placeholder_global_unrolled_iter_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 64, placeholder_d_global_unrolled_iter_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, ethosu_write_1[10], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_global_unrolled_iter_1[0], 272, 12, placeholder_d_global_unrolled_iter_1[0], 64, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py b/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py index e1c633e1d569..498609fb15b7 100644 --- a/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py +++ b/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py @@ -33,8 +33,8 @@ def _get_unary_elementwise_args(call, include_buffers=False, remove_constants=Fa for i, arg in enumerate(args): if isinstance(arg, tvm.tir.expr.IntImm) or isinstance(arg, tvm.tir.expr.FloatImm): unary_elementwise_args.append(arg.value) - elif isinstance(arg, tvm.tir.expr.Load) and not include_buffers: - unary_elementwise_args.append(arg.index) + elif isinstance(arg, tvm.tir.expr.BufferLoad) and not include_buffers: + unary_elementwise_args.append(arg.indices[0]) else: unary_elementwise_args.append(arg) diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index 
6a4aba4e38fc..5c6f064873ef 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -180,25 +180,29 @@ def test_schedule_cache_reads(): @tvm.script.ir_module class DiamondGraphTir: @T.prim_func - def main(input_buffer: T.Buffer[(1, 56, 56, 96), "int8"], output_buffer: T.Buffer[(1, 56, 56, 24), "int8"]) -> None: + def main(input_buffer: T.Buffer[(301056,), "int8"], output_buffer: T.Buffer[(75264,), "int8"]) -> None: T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - weight_buffer = T.buffer_var("uint8", "") - bias_buffer = T.buffer_var("uint8", "") - weight_buffer2 = T.buffer_var("uint8", "") - bias_buffer2 = T.buffer_var("uint8", "") + T.preflattened_buffer(input_buffer, [1, 56, 56, 96], dtype='int8', data=input_buffer.data) + T.preflattened_buffer(output_buffer, [1, 56, 56, 24], dtype='int8', data=output_buffer.data) - placeholder_global = T.allocate([2608], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_d_global = T.allocate([240], "uint8", "global", annotations={"disable_lower_builtin":True}) + weight_buffer = T.buffer_decl([2608], "uint8") + bias_buffer = T.buffer_decl([240], "uint8") + weight_buffer2 = T.buffer_decl([736], "uint8") + bias_buffer2 = T.buffer_decl([240], "uint8") + + weight_global = T.allocate([2608], "uint8", "global", annotations={"disable_lower_builtin":True}) + weight_global2 = T.buffer_decl([736], "uint8", data=weight_global.data) + bias_global = T.allocate([240], "uint8", "global", annotations={"disable_lower_builtin":True}) featuremap_buffer = T.allocate([75264], "int8", "global", annotations={"disable_lower_builtin": True}) featuremap_buffer2 = T.allocate([75264], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", weight_buffer, 0), 2608, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", bias_buffer, 0), 240, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, T.load("int8", input_buffer.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 2608, 12, T.load("uint8", placeholder_d_global, 0), 240, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", weight_buffer2, 0), 736, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", bias_buffer2, 0), 240, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 736, 12, T.load("uint8", placeholder_d_global, 0), 240, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer, 0), 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer2, 0), 0, 0, 0, 
T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, T.load("int8", output_buffer.data, 0), 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "ADD", 0, "NONE", 0, 0, "TFL", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", weight_buffer[0], 2608, weight_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", bias_buffer[0], 240, bias_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, input_buffer[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, weight_global[0], 2608, 12, bias_global[0], 240, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", weight_buffer2[0], 736, weight_global2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", bias_buffer2[0], 240, bias_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, weight_global2[0], 736, 12, bias_global[0], 240, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer2[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, output_buffer[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "ADD", 0, "NONE", 0, 0, "TFL", dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index de214888be6b..8169f7b86d5b 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -33,13 +33,13 @@ @tvm.script.ir_module class SingleEthosUConv2D: @T.prim_func - def main(placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_conv2d_1: T.Buffer[(1, 8, 8, 16), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(8192,), "int8"], ethosu_conv2d_1: T.Buffer[(1024,), "int8"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - placeholder_4 = T.buffer_var("uint8", "") - placeholder_5 = T.buffer_var("uint8", "") + placeholder_4 = T.buffer_decl([1], "uint8") + placeholder_5 = T.buffer_decl([1], "uint8") # body - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 8, 8, 3, 8, 0, 8, T.load("uint8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 8, 8, 16, 8, 0, 8, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_4, 0), 0, 12, T.load("uint8", placeholder_5, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 8, 8, 3, 8, 0, 8, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 8, 8, 16, 8, 0, 8, ethosu_conv2d_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_4[0], 0, 12, placeholder_5[0], 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) # fmt: on @@ -48,20 +48,20 @@ def main(placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], 
ethosu_conv2d_1: T.Bu @tvm.script.ir_module class MultiEthosUConv2D: @T.prim_func - def main(placeholder_6: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_conv2d_1: T.Buffer[(1, 8, 8, 8), "int8"]) -> None: + def main(placeholder_6: T.Buffer[(192,), "int8"], ethosu_conv2d_1: T.Buffer[(512,), "int8"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - placeholder_9 = T.buffer_var("uint8", "") - placeholder_7 = T.buffer_var("uint8", "") - placeholder_8 = T.buffer_var("uint8", "") - placeholder_5 = T.buffer_var("uint8", "") + placeholder_9 = T.buffer_decl([1], "uint8") + placeholder_7 = T.buffer_decl([1], "uint8") + placeholder_8 = T.buffer_decl([1], "uint8") + placeholder_5 = T.buffer_decl([1], "uint8") # body ethosu_conv2d_2 = T.allocate([1024], "uint8", "global") ethosu_conv2d_3 = T.allocate([2048], "uint8", "global") - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, T.load("uint8", placeholder_6.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_7, 0), 0, 12, T.load("uint8", placeholder_8, 0), 0, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="uint8")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_9, 0), 0, 12, T.load("uint8", placeholder_5, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, T.load("uint8", placeholder_6.data, 96), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_7, 0), 0, 12, T.load("uint8", placeholder_8, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, T.load("uint8", ethosu_conv2d_1.data, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_9, 0), 0, 12, T.load("uint8", placeholder_5, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, placeholder_6[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, ethosu_conv2d_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, placeholder_7[0], 0, 12, placeholder_8[0], 0, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, ethosu_conv2d_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, ethosu_conv2d_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_9[0], 0, 12, placeholder_5[0], 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, placeholder_6[96], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, ethosu_conv2d_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, 
+        T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, ethosu_conv2d_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, ethosu_conv2d_1[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_9[0], 0, 12, placeholder_5[0], 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8"))
 # fmt: on
@@ -70,17 +70,17 @@ def main(placeholder_6: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_conv2d_1: T.Buffe
 @tvm.script.ir_module
 class MultiEthosUCopy:
     @T.prim_func
-    def main(placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_conv2d_1: T.Buffer[(1, 16, 16, 8), "int8"]) -> None:
+    def main(placeholder_3: T.Buffer[(8192,), "int8"], ethosu_conv2d_1: T.Buffer[(2048,), "int8"]) -> None:
         # function attr dict
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        placeholder_5 = T.buffer_var("uint8", "")
-        placeholder_4 = T.buffer_var("uint8", "")
+        placeholder_5 = T.buffer_decl([1], "int32")
+        placeholder_4 = T.buffer_decl([1], "uint8")
         # body
         placeholder_global = T.allocate([256], "uint8", "global")
         placeholder_d_global = T.allocate([8], "int32", "global")
-        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", placeholder_4, 0), 256, T.load("uint8", placeholder_global, 0), dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", T.load("int32", placeholder_5, 0), 8, T.load("int32", placeholder_d_global, 0), dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, T.load("uint8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 8, 16, 0, 16, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 0, 12, T.load("uint8", placeholder_d_global, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", placeholder_4[0], 256, placeholder_global[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", placeholder_5[0], 8, placeholder_d_global[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 8, 16, 0, 16, ethosu_conv2d_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 0, 12, placeholder_d_global[0], 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="handle"))
 # fmt: on
@@ -89,15 +89,15 @@ def main(placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_conv2d_1: T.Bu
 @tvm.script.ir_module
 class WeightStreamOnly:
     @T.prim_func
-    def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None:
-        buffer = T.buffer_var("uint8", "")
-        buffer_1 = T.buffer_var("uint8", "")
-        buffer_2 = T.buffer_var("uint8", "")
-        buffer_3 = T.buffer_var("uint8", "")
-        buffer_4 = T.buffer_var("uint8", "")
-        buffer_5 = T.buffer_var("uint8", "")
-        buffer_6 = T.buffer_var("uint8", "")
-        buffer_7 = T.buffer_var("uint8", "")
+    def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
+        buffer = T.buffer_decl([1], "uint8")
+        buffer_1 = T.buffer_decl([1], "uint8")
+        buffer_2 = T.buffer_decl([1], "uint8")
+        buffer_3 = T.buffer_decl([1], "uint8")
+        buffer_4 = T.buffer_decl([1], "uint8")
+        buffer_5 = T.buffer_decl([1], "uint8")
+        buffer_6 = T.buffer_decl([1], "uint8")
+        buffer_7 = T.buffer_decl([1], "uint8")
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True,
@@ -112,18 +112,18 @@ def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[
         # body
         placeholder_global = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True})
         placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True})
-        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 128, T.load("uint8", placeholder_global, 0), dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 128, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_7, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", buffer[0], 128, placeholder_global[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 32, placeholder_d_global[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 128, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 112, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 112, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_4[0], 112, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_5[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 112, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_6[0], 112, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_7[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 112, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -133,17 +133,17 @@ def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[ @tvm.script.ir_module class MixedRead: @T.prim_func - def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("uint8", "") - buffer_3 = T.buffer_var("uint8", "") - buffer_4 = T.buffer_var("uint8", "") - buffer_5 = T.buffer_var("uint8", "") - buffer_6 = T.buffer_var("uint8", "") - buffer_7 = T.buffer_var("uint8", "") - buffer_8 = T.buffer_var("uint8", "") - buffer_9 = T.buffer_var("uint8", "") + def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + buffer = T.buffer_decl([1], "uint8") + buffer_1 = T.buffer_decl([1], "uint8") + buffer_2 = T.buffer_decl([1], "uint8") + buffer_3 = T.buffer_decl([1], "uint8") + buffer_4 = T.buffer_decl([1], "uint8") + buffer_5 = T.buffer_decl([1], "uint8") + buffer_6 = T.buffer_decl([1], "uint8") + buffer_7 = T.buffer_decl([1], "uint8") + buffer_8 = T.buffer_decl([1], "uint8") + buffer_9 = T.buffer_decl([1], "uint8") # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True, @@ -161,19 +161,19 @@ def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[ ethosu_write_1 = T.allocate([4096], "int8", 
"global", annotations={"disable_lower_builtin":True}) placeholder_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 592, 12, T.load("uint8", buffer_1, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_7, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_8, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_9, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + 
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, 12, buffer_1[0], 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 80, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_4[0], 80, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_5[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_6[0], 80, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_7[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_8[0], 80, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_9[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -523,12 +523,12 @@ class SingleEthosuDepthwiseConv2D: def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_depthwise_conv2d: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - placeholder_4 = T.match_buffer(placeholder_1, [3, 3, 2, 1], dtype="int8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = T.match_buffer(placeholder_2, [3, 10], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_3 = T.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_depthwise_conv2d_1 = T.match_buffer(ethosu_depthwise_conv2d, [1, 6, 7, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + placeholder_4 = T.match_buffer(placeholder_1, [18], dtype="int8", elem_offset=0, align=128, offset_factor=1) + 
placeholder_5 = T.match_buffer(placeholder_2, [30], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = T.match_buffer(placeholder, [192], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_depthwise_conv2d_1 = T.match_buffer(ethosu_depthwise_conv2d, [126], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 8, 8, 3, 8, 0, 8, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.6), 11, "NHWC", 24, 3, 1, "int8", 6, 7, 3, 6, 0, 7, T.load("int8", ethosu_depthwise_conv2d_1.data, 0), 0, 0, 0, T.float32(0.26), 15, "NHWC", 21, 3, 1, 2, 3, 1, 1, 1, 1, T.load("int8", placeholder_4.data, 0), 18, 13, T.load("uint8", placeholder_5.data, 0), 30, 0, 0, 0, 0, "CLIP", 15, 105, "TFL", "NONE", dtype="int8")) + T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 8, 8, 3, 8, 0, 8, placeholder_3[0], 0, 0, 0, T.float32(0.6), 11, "NHWC", 24, 3, 1, "int8", 6, 7, 3, 6, 0, 7, ethosu_depthwise_conv2d_1[0], 0, 0, 0, T.float32(0.26), 15, "NHWC", 21, 3, 1, 2, 3, 1, 1, 1, 1, placeholder_4[0], 18, 13, placeholder_5[0], 30, 0, 0, 0, 0, "CLIP", 15, 105, "TFL", "NONE", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -655,8 +655,8 @@ def populate_ethosu_copy_calls(stmt): ethosu_copy_calls = extract_ethosu_copy_extern_calls(test_case["tir_module"]) for idx, ethosu_copy_call in enumerate(ethosu_copy_calls): npu_dma_op = tir_to_cs_translator.translate_ethosu_tir_call_extern(ethosu_copy_call) - assert npu_dma_op.src.address.buffer_var.name == test_case["ref"][idx]["src"] - assert npu_dma_op.dest.address.buffer_var.name == test_case["ref"][idx]["dest"] + assert npu_dma_op.src.address.buffer.name == test_case["ref"][idx]["src"] + assert npu_dma_op.dest.address.buffer.name == test_case["ref"][idx]["dest"] assert npu_dma_op.src.length == test_case["ref"][idx]["length"] assert npu_dma_op.dest.length == test_case["ref"][idx]["length"] @@ -665,10 +665,10 @@ def populate_ethosu_copy_calls(stmt): @tvm.script.ir_module class MixedConstantDatatypes: @T.prim_func - def main(placeholder_4: T.Buffer[(1, 8, 16, 16), "int8"], ethosu_write_1: T.Buffer[(1, 1, 1, 16), "int8"]) -> None: - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("int16", "") + def main(placeholder_4: T.Buffer[(2048,), "int8"], ethosu_write_1: T.Buffer[(16,), "int8"]) -> None: + buffer = T.buffer_decl([1], "uint8") + buffer_1 = T.buffer_decl([1], "uint8") + buffer_2 = T.buffer_decl([1], "int16") # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True, @@ -680,11 +680,11 @@ def main(placeholder_4: T.Buffer[(1, 8, 16, 16), "int8"], ethosu_write_1: T.Buff placeholder_d_global = T.allocate([160], "uint8", "global") ethosu_write_2 = T.allocate([16], "int16", "global") placeholder_d_global_1 = T.allocate([1], "int16", "global") - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 272, T.load("uint8", placeholder_global, 0), dtype="uint8")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 160, T.load("uint8", placeholder_d_global, 0), dtype="uint8")) - T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 8, 16, 16, 8, 0, 16, T.load("int8", placeholder_4.data, 0), 0, 0, 0, T.float32(0.0039215548895299435), -128, "NHWC", 256, 16, 1, "int16", 1, 1, 16, 1, 0, 1, T.load("int16", ethosu_write_2, 0), 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, 16, 8, 1, 1, 1, 1, T.load("uint8", 
placeholder_global, 0), 272, 0, T.load("uint8", placeholder_d_global, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="int16")) - T.evaluate(T.call_extern("ethosu_copy", T.load("int16", buffer_2, 0), 1, T.load("int16", placeholder_d_global_1, 0), dtype="int16")) - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int16", 1, 1, 16, 1, 0, 1, T.load("int16", ethosu_write_2, 0), 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, "int16", 1, 1, 1, 1, 0, 1, T.load("int16", placeholder_d_global_1, 0), 0, 0, 0, T.float32(0.0078125018482064768), 0, "NHWC", 1, 1, 1, "int8", 1, 1, 16, 1, 0, 1, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, "MUL", 0, "NONE", 0, 0, "NATURAL", dtype="int8")) + T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 272, placeholder_global[0], dtype="uint8")) + T.evaluate(T.call_extern("ethosu_copy", buffer[0], 160, placeholder_d_global[0], dtype="uint8")) + T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 8, 16, 16, 8, 0, 16, placeholder_4[0], 0, 0, 0, T.float32(0.0039215548895299435), -128, "NHWC", 256, 16, 1, "int16", 1, 1, 16, 1, 0, 1, ethosu_write_2[0], 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, 16, 8, 1, 1, 1, 1, placeholder_global[0], 272, 0, placeholder_d_global[0], 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="int16")) + T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 1, placeholder_d_global_1[0], dtype="int16")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int16", 1, 1, 16, 1, 0, 1, ethosu_write_2[0], 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, "int16", 1, 1, 1, 1, 0, 1, placeholder_d_global_1[0], 0, 0, 0, T.float32(0.0078125018482064768), 0, "NHWC", 1, 1, 1, "int8", 1, 1, 16, 1, 0, 1, ethosu_write_1[0], 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, "MUL", 0, "NONE", 0, 0, "NATURAL", dtype="int8")) # fmt: on @@ -901,11 +901,11 @@ def check_buffer(address, region, length, buffer_var): for npu_op in npu_ops: if isinstance(npu_op, vapi.NpuDmaOperation): - src_tir_buffer_var = npu_op_tir_buffers[npu_op][0].buffer_var + src_tir_buffer_var = npu_op_tir_buffers[npu_op][0].buffer.data check_buffer( npu_op.src.address, npu_op.src.region, npu_op.src.length, src_tir_buffer_var ) - dest_tir_load = npu_op_tir_buffers[npu_op][1].buffer_var + dest_tir_load = npu_op_tir_buffers[npu_op][1].buffer.data check_buffer( npu_op.dest.address, npu_op.dest.region, @@ -913,7 +913,7 @@ def check_buffer(address, region, length, buffer_var): dest_tir_load, ) elif issubclass(type(npu_op), vapi.NpuBlockOperation): - ifm_tir_buffer_var = npu_op_tir_buffers[npu_op][0].buffer_var + ifm_tir_buffer_var = npu_op_tir_buffers[npu_op][0].buffer.data ifm_length = ( npu_op.ifm.shape.height * npu_op.ifm.shape.width * npu_op.ifm.shape.depth ) @@ -923,7 +923,7 @@ def check_buffer(address, region, length, buffer_var): ifm_length, ifm_tir_buffer_var, ) - ofm_tir_buffer_var = npu_op_tir_buffers[npu_op][1].buffer_var + ofm_tir_buffer_var = npu_op_tir_buffers[npu_op][1].buffer.data ofm_length = ( npu_op.ofm.shape.height * npu_op.ofm.shape.width * npu_op.ofm.shape.depth ) @@ -939,7 +939,7 @@ def check_buffer(address, region, length, buffer_var): npu_op.weights[idx].address, npu_op.weights[idx].region, npu_op.weights[idx].length, - weight.address.buffer_var, + weight.address.buffer.data, ) for idx, bias in enumerate(npu_op_tir_buffers[npu_op][3]): assert isinstance(bias, vapi.NpuAddressRange) @@ -947,7 +947,7 @@ def 
check_buffer(address, region, length, buffer_var): npu_op.biases[idx].address, npu_op.biases[idx].region, npu_op.biases[idx].length, - bias.address.buffer_var, + bias.address.buffer.data, ) for test_case in test_cases: @@ -989,10 +989,10 @@ class SingleEthosuPooling: def main(placeholder: T.handle, placeholder_3: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - placeholder_4 = T.match_buffer(placeholder, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 5, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + placeholder_4 = T.match_buffer(placeholder, [135], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [75], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_pooling", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_4.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 5, 3, 5, 0, 5, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 15, 3, 1, "AVG", 2, 3, 2, 1, 1, 1, 1, 1, 1, 0, "CLIP", 10, 100, "TFL", "NONE", dtype="int8")) + T.evaluate(T.call_extern("ethosu_pooling", "int8", 5, 9, 3, 5, 0, 9, placeholder_4[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 5, 3, 5, 0, 5, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 15, 3, 1, "AVG", 2, 3, 2, 1, 1, 1, 1, 1, 1, 0, "CLIP", 10, 100, "TFL", "NONE", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1066,10 +1066,10 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1 ) ethosu_write_2 = T.match_buffer( - ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1 + ethosu_write, [135], dtype="int8", elem_offset=0, align=128, offset_factor=1 ) # body - T.evaluate(T.call_extern( "ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "ADD", 0, "CLIP", 10, 100, "TFL", dtype="int8")) + T.evaluate(T.call_extern( "ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "ADD", 0, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1083,9 +1083,9 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [135], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", 
placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SUB", 0, "CLIP", 10, 100, "TFL", dtype="int8")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SUB", 0, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1098,9 +1098,9 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [135], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MUL", 0, "CLIP", 10, 100, "TFL", dtype="int8")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MUL", 0, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1114,9 +1114,9 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [135], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MIN", 0, "CLIP", 10, 100, "TFL", dtype="int8")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MIN", 0, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1130,9 +1130,9 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = 
T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [135], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MAX", 0, "CLIP", 10, 100, "TFL", dtype="int8")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MAX", 0, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1146,9 +1146,9 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [270], dtype="int32", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int32", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [135], dtype="int32", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 5, 9, 3, 5, 0, 9, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SHR", 0, "NONE", 0, 0, "TFL", dtype="int32")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SHR", 0, "NONE", 0, 0, "TFL", dtype="int32")) __tvm_meta__ = None # fmt: on @@ -1162,9 +1162,9 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [270], dtype="int32", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int32", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [135], dtype="int32", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 5, 9, 3, 5, 0, 9, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SHL", 0, "CLIP", 10, 100, "TFL", dtype="int32")) + 
T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SHL", 0, "CLIP", 10, 100, "TFL", dtype="int32")) __tvm_meta__ = None # fmt: on @@ -1283,9 +1283,9 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [24], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "ADD", 1, "CLIP", 10, 100, "TFL", dtype="int8")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "ADD", 1, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1298,9 +1298,9 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [24], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SUB", 1, "CLIP", 10, 100, "TFL", dtype="int8")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SUB", 1, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1313,9 +1313,9 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = 
T.match_buffer(ethosu_write, [24], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MUL", 1, "CLIP", 10, 100, "TFL", dtype="int8")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MUL", 1, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1329,9 +1329,9 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [24], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MIN", 1, "CLIP", 10, 100, "TFL", dtype="int8")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MIN", 1, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1345,9 +1345,9 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [24], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MAX", 1, "CLIP", 10, 100, "TFL", dtype="int8")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, 
"NHWC", 12, 4, 1, "MAX", 1, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1361,9 +1361,9 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [27], dtype="int32", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int32", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [24], dtype="int32", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 2, 3, 4, 2, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int32", 1, 3, 1, 1, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int32", 2, 3, 4, 2, 0, 3, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SHR", 1, "NONE", 0, 0, "TFL", dtype="int32")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int32", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int32", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SHR", 1, "NONE", 0, 0, "TFL", dtype="int32")) __tvm_meta__ = None # fmt: on @@ -1377,9 +1377,9 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [27], dtype="int32", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int32", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [24], dtype="int32", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 2, 3, 4, 2, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int32", 1, 3, 1, 1, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int32", 2, 3, 4, 2, 0, 3, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SHL", 1, "CLIP", 10, 100, "TFL", dtype="int32")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int32", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int32", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SHL", 1, "CLIP", 10, 100, "TFL", dtype="int32")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_vela_api.py b/tests/python/contrib/test_ethosu/test_vela_api.py index af75dc82a0bb..662b35822cc2 100644 --- a/tests/python/contrib/test_ethosu/test_vela_api.py +++ b/tests/python/contrib/test_ethosu/test_vela_api.py @@ -50,7 +50,7 @@ def main( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_3 = T.match_buffer( - placeholder, [1, 8, 8, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1 + placeholder, [192], dtype="uint8", elem_offset=0, align=128, offset_factor=1 ) placeholder_4 = T.match_buffer( placeholder_1, [48], dtype="uint8", elem_offset=0, align=128, offset_factor=1 @@ -59,7 +59,7 @@ def main( 
placeholder_2, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1 ) ethosu_conv2d_1 = T.match_buffer( - ethosu_conv2d, [1, 8, 8, 16], dtype="uint8", elem_offset=0, align=128, offset_factor=1 + ethosu_conv2d, [1024], dtype="uint8", elem_offset=0, align=128, offset_factor=1 ) # body T.evaluate( @@ -72,7 +72,7 @@ def main( 8, 0, 8, - T.load("uint8", placeholder_3.data, 0), + placeholder_3[0], 0, 0, 0, @@ -89,7 +89,7 @@ def main( 8, 0, 8, - T.load("uint8", ethosu_conv2d_1.data, 0), + ethosu_conv2d_1[0], 0, 0, 0, @@ -105,10 +105,10 @@ def main( 1, 1, 1, - T.load("uint8", placeholder_4.data, 0), + placeholder_4[0], 0, 12, - T.load("uint8", placeholder_5.data, 0), + placeholder_5[0], 0, 0, 0, @@ -142,10 +142,10 @@ def main( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_3 = T.match_buffer( - placeholder, [1, 8, 8, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1 + placeholder, [192], dtype="uint8", elem_offset=0, align=128, offset_factor=1 ) placeholder_4 = T.match_buffer( - placeholder_1, [16, 1, 1, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1 + placeholder_1, [48], dtype="uint8", elem_offset=0, align=128, offset_factor=1 ) placeholder_5 = T.match_buffer( placeholder_2, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1 @@ -155,7 +155,7 @@ def main( placeholder_6, [16], dtype="float32", elem_offset=0, align=128, offset_factor=1 ) ethosu_conv2d_1 = T.match_buffer( - ethosu_conv2d, [1, 8, 8, 16], dtype="uint8", elem_offset=0, align=128, offset_factor=1 + ethosu_conv2d, [1024], dtype="uint8", elem_offset=0, align=128, offset_factor=1 ) # body T.evaluate( @@ -168,7 +168,7 @@ def main( 8, 0, 8, - T.load("uint8", placeholder_3.data, 0), + placeholder_3[0], 0, 0, 0, @@ -185,7 +185,7 @@ def main( 8, 0, 8, - T.load("uint8", ethosu_conv2d_1.data, 0), + ethosu_conv2d_1[0], 0, 0, 0, @@ -201,10 +201,10 @@ def main( 1, 1, 1, - T.load("uint8", placeholder_4.data, 0), + placeholder_4[0], 0, 12, - T.load("uint8", placeholder_5.data, 0), + placeholder_5[0], 0, 0, 0, diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 2ce36f19fcc8..a3004904b119 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -766,24 +766,27 @@ class Model(tf.Module): def tf_function(self, x): # Use tf.nn API to create the model tf_strides = [1, strides[0], strides[1], 1] + filter_shape = [kernel_shape[0], kernel_shape[1], 3, 3] + filter1 = tf.constant( + np.arange(np.prod(filter_shape)).reshape(filter_shape), + dtype=tf.float32, + ) op = tf.nn.conv2d( x, - filters=tf.constant( - np.random.uniform(size=[kernel_shape[0], kernel_shape[1], 3, 3]), - dtype=tf.float32, - ), + filters=filter1, strides=tf_strides, padding=padding, dilations=dilation, ) op = tf.nn.relu(op) # Second convolution + filter2 = tf.constant( + 1000 + np.arange(np.prod(filter_shape)).reshape(filter_shape), + dtype=tf.float32, + ) op2 = tf.nn.conv2d( x, - filters=tf.constant( - np.random.uniform(size=(kernel_shape[0], kernel_shape[1], 3, 3)), - dtype=tf.float32, - ), + filters=filter2, strides=strides, padding=padding, data_format="NHWC", diff --git a/tests/python/unittest/test_lower_build.py b/tests/python/unittest/test_lower_build.py index 9848aed4b51c..bd820b617c2d 100644 --- a/tests/python/unittest/test_lower_build.py +++ b/tests/python/unittest/test_lower_build.py @@ -53,37 +53,41 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: @tvm.script.ir_module class LoweredModule: 
@T.prim_func - def main(a: T.handle, b: T.handle, c: T.handle) -> None: + def main( + A: T.Buffer[(16384,), "float32"], + B: T.Buffer[(16384,), "float32"], + C: T.Buffer[(16384,), "float32"], + ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "from_legacy_te_schedule": True, "tir.noalias": True}) - A = T.match_buffer(a, [128, 128]) - B = T.match_buffer(b, [128, 128]) - C = T.match_buffer(c, [128, 128]) + T.preflattened_buffer(A, [128, 128], data=A.data) + T.preflattened_buffer(B, [128, 128], data=B.data) + T.preflattened_buffer(C, [128, 128], data=C.data) # body for x, y in T.grid(128, 128): - C.data[x * 128 + y] = 0.0 + C[x * 128 + y] = 0.0 for k in T.serial(0, 128): - C.data[x * 128 + y] = T.load("float32", C.data, x * 128 + y) + T.load( - "float32", A.data, x * 128 + k - ) * T.load("float32", B.data, y * 128 + k) + C[x * 128 + y] = C[x * 128 + y] + A[x * 128 + k] * B[y * 128 + k] @tvm.script.ir_module class LoweredTIRModule: @T.prim_func - def main(a: T.handle, b: T.handle, c: T.handle) -> None: + def main( + A: T.Buffer[(16384,), "float32"], + B: T.Buffer[(16384,), "float32"], + C: T.Buffer[(16384,), "float32"], + ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A = T.match_buffer(a, [128, 128]) - B = T.match_buffer(b, [128, 128]) - C = T.match_buffer(c, [128, 128]) + T.preflattened_buffer(A, [128, 128], data=A.data) + T.preflattened_buffer(B, [128, 128], data=B.data) + T.preflattened_buffer(C, [128, 128], data=C.data) # body for x, y in T.grid(128, 128): - C.data[x * 128 + y] = 0.0 + C[x * 128 + y] = 0.0 for k in T.serial(0, 128): - C.data[x * 128 + y] = T.load("float32", C.data, x * 128 + y) + T.load( - "float32", A.data, x * 128 + k - ) * T.load("float32", B.data, y * 128 + k) + C[x * 128 + y] = C[x * 128 + y] + A[x * 128 + k] * B[y * 128 + k] def test_lower_build_te_schedule(): diff --git a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py index db302f4b7e4d..041f641a45d2 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py +++ b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py @@ -60,8 +60,8 @@ def main(a: T.handle, b: T.handle) -> None: blockIdx_x = T.env_thread("blockIdx.x") blockIdx_y = T.env_thread("blockIdx.y") blockIdx_z = T.env_thread("blockIdx.z") - A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") - B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + A = T.match_buffer(a, [14*14*256*256], dtype="float32") + B = T.match_buffer(b, [14*14*512*256], dtype="float32") # body T.launch_thread(blockIdx_z, 196) B_local = T.allocate([64], "float32", "local") @@ -72,17 +72,22 @@ def main(a: T.handle, b: T.handle) -> None: T.launch_thread(threadIdx_y, 8) T.launch_thread(threadIdx_x, 8) for ff_c_init, nn_c_init in T.grid(8, 8): - T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + B_local[ff_c_init * 8 + nn_c_init] = T.float32(0) for rc_outer, ry, rx in T.grid(32, 3, 3): for ax3_inner_outer in T.serial(0, 2): - T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), 
T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + Apad_shared[T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4)] = T.if_then_else( + 1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, + A[T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4)], + T.broadcast(T.float32(0), 4), + dtype="float32x4", + ) for rc_inner in T.serial(0, 8): for ax3 in T.serial(0, 8): - T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + Apad_shared_local[ax3] = Apad_shared[rc_inner * 64 + threadIdx_x * 8 + ax3] for ff_c, nn_c in T.grid(8, 8): - T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + B_local[ff_c * 8 + nn_c] = B_local[ff_c * 8 + nn_c] + Apad_shared_local[nn_c] for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): - T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + B[blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner] = B_local[ff_inner_inner_inner * 8 + nn_inner_inner_inner] # fmt: on @tvm.script.ir_module @@ -97,8 +102,8 @@ def main(a: T.handle, b: T.handle) -> None: blockIdx_x = T.env_thread("blockIdx.x") blockIdx_y = T.env_thread("blockIdx.y") blockIdx_z = T.env_thread("blockIdx.z") - A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") - B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + A = T.match_buffer(a, [14*14*256*256], dtype="float32") + B = T.match_buffer(b, [14*14*512*256], dtype="float32") # body T.launch_thread(blockIdx_z, 196) B_local = T.allocate([6400000], "float32", "local") @@ -109,17 +114,26 @@ def main(a: T.handle, b: T.handle) -> None: T.launch_thread(threadIdx_y, 8) T.launch_thread(threadIdx_x, 8) for ff_c_init, nn_c_init in T.grid(8, 8): - T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + B_local[ff_c_init * 8 + nn_c_init] = T.float32(0) + # Access of the last element of B_local prevents buffer + # compacting from reducing the amount of local memory + # used. 
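The added access on the next line is a deliberate dead store to the final element of B_local, so a later buffer-compaction pass must keep the full 6400000-element extent alive. As a hedged illustration of why that defeats the size check (a standalone sketch, not TVM's actual VerifyGPUCode implementation; the helper names and the 64 KiB budget are assumptions):

# Illustrative sketch only -- approximates a per-block local-memory check.
DTYPE_BYTES = {"int8": 1, "int16": 2, "int32": 4, "float32": 4}

def local_memory_bytes(allocations):
    """allocations: iterable of (extent, dtype, scope) triples."""
    return sum(
        extent * DTYPE_BYTES[dtype]
        for extent, dtype, scope in allocations
        if scope == "local"
    )

# Because its last element is still touched, B_local keeps its declared
# 6400000-element extent and exceeds any realistic per-block budget.
assert local_memory_bytes([(6400000, "float32", "local")]) > 64 * 1024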
+ B_local[6400000-1 + ff_c_init*8] = 0.0 for rc_outer, ry, rx in T.grid(32, 3, 3): for ax3_inner_outer in T.serial(0, 2): - T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + Apad_shared[T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4)] = T.if_then_else( + 1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, + A[T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4)], + T.broadcast(T.float32(0), 4), + dtype="float32x4", + ) for rc_inner in T.serial(0, 8): for ax3 in T.serial(0, 8): - T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + Apad_shared_local[ax3] = Apad_shared[rc_inner * 64 + threadIdx_x * 8 + ax3] for ff_c, nn_c in T.grid(8, 8): - T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + B_local[ff_c * 8 + nn_c] = B_local[ff_c * 8 + nn_c] + Apad_shared_local[nn_c] for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): - T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + B[blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner] = B_local[ff_inner_inner_inner * 8 + nn_inner_inner_inner]# fmt: on @tvm.script.ir_module @@ -134,8 +148,8 @@ def main(a: T.handle, b: T.handle) -> None: blockIdx_x = T.env_thread("blockIdx.x") blockIdx_y = T.env_thread("blockIdx.y") blockIdx_z = T.env_thread("blockIdx.z") - A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") - B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + A = T.match_buffer(a, [14*14*256*256], dtype="float32") + B = T.match_buffer(b, [14*14*512*256], dtype="float32") # body T.launch_thread(blockIdx_z, 196) B_local = T.allocate([64], "float32", "local") @@ -146,17 +160,26 @@ def main(a: T.handle, b: T.handle) -> None: T.launch_thread(threadIdx_y, 8) T.launch_thread(threadIdx_x, 8) for ff_c_init, nn_c_init in T.grid(8, 8): - T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + B_local[ff_c_init * 8 + nn_c_init] = T.float32(0) for rc_outer, ry, rx in T.grid(32, 3, 3): for ax3_inner_outer in T.serial(0, 2): - T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + 
Apad_shared[T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4)] = T.if_then_else( + 1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, + A[T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4)], + T.broadcast(T.float32(0), 4), + dtype="float32x4", + ) + # Access of the last element of Apad_shared prevents + # buffer compacting from reducing the amount of shared + # memory used. + Apad_shared[512000-1] = 0.0 for rc_inner in T.serial(0, 8): for ax3 in T.serial(0, 8): - T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + Apad_shared_local[ax3] = Apad_shared[rc_inner * 64 + threadIdx_x * 8 + ax3] for ff_c, nn_c in T.grid(8, 8): - T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + B_local[ff_c * 8 + nn_c] = B_local[ff_c * 8 + nn_c] + Apad_shared_local[nn_c] for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): - T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + B[blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner] = B_local[ff_inner_inner_inner * 8 + nn_inner_inner_inner]# fmt: on @tvm.script.ir_module @@ -171,8 +194,8 @@ def main(a: T.handle, b: T.handle) -> None: blockIdx_x = T.env_thread("blockIdx.x") blockIdx_y = T.env_thread("blockIdx.y") blockIdx_z = T.env_thread("blockIdx.z") - A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") - B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + A = T.match_buffer(a, [14*14*256*256], dtype="float32") + B = T.match_buffer(b, [14*14*512*256], dtype="float32") # body T.launch_thread(blockIdx_z, 196) B_local = T.allocate([64], "float32", "local") @@ -183,17 +206,22 @@ def main(a: T.handle, b: T.handle) -> None: T.launch_thread(threadIdx_y, 8) T.launch_thread(threadIdx_x, 800000) for ff_c_init, nn_c_init in T.grid(8, 8): - T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + B_local[ff_c_init * 8 + nn_c_init] = T.float32(0) for rc_outer, ry, rx in T.grid(32, 3, 3): for ax3_inner_outer in T.serial(0, 2): - T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + Apad_shared[T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4)] = T.if_then_else( + 1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, + A[T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4)], + T.broadcast(T.float32(0), 4), + dtype="float32x4", + ) for rc_inner in 
T.serial(0, 8): for ax3 in T.serial(0, 8): - T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + Apad_shared_local[ax3] = Apad_shared[rc_inner * 64 + threadIdx_x * 8 + ax3] for ff_c, nn_c in T.grid(8, 8): - T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + B_local[ff_c * 8 + nn_c] = B_local[ff_c * 8 + nn_c] + Apad_shared_local[nn_c] for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): - T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + B[blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner] = B_local[ff_inner_inner_inner * 8 + nn_inner_inner_inner]# fmt: on @T.prim_func def GmmCuda0(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None: @@ -388,6 +416,8 @@ def test_postproc_verify_gpu_2(): mod = Conv2dCuda2 ctx = _create_context(mod, target=_target()) sch = tir.Schedule(mod, debug_mask="all") + # Should fail due to too much shared memory per block (large + # Apad_shared allocation). assert not ctx.postprocs[0].apply(sch) @@ -395,6 +425,8 @@ def test_postproc_verify_gpu_3(): mod = Conv2dCuda3 ctx = _create_context(mod, target=_target()) sch = tir.Schedule(mod, debug_mask="all") + # Should fail due to too many threads per block (large + # threadIdx.x extent). assert not ctx.postprocs[0].apply(sch) diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py index 5788e4abc0fe..c63dd87b4e36 100644 --- a/tests/python/unittest/test_runtime_module_based_interface.py +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -656,7 +656,7 @@ def make_func(symbol): 0, n - 1, tvm.tir.ForKind.SERIAL, - tvm.tir.Store(Ab.data, tvm.tir.Load("float32", Ab.data, i) + 1, i + 1), + tvm.tir.BufferStore(Ab, tvm.tir.BufferLoad(Ab, [i]) + 1, [i + 1]), ) return tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", symbol) diff --git a/tests/python/unittest/test_runtime_module_load.py b/tests/python/unittest/test_runtime_module_load.py index f17a615ce2c1..9d067630879a 100644 --- a/tests/python/unittest/test_runtime_module_load.py +++ b/tests/python/unittest/test_runtime_module_load.py @@ -59,7 +59,7 @@ def save_object(names): 0, n - 1, tvm.tir.ForKind.SERIAL, - tvm.tir.Store(Ab.data, tvm.tir.Load(dtype, Ab.data, i) + 1, i + 1), + tvm.tir.BufferStore(Ab, tvm.tir.BufferLoad(Ab, [i]) + 1, [i + 1]), ) mod = tvm.IRModule.from_expr( tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "main") diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index 305f82558edc..994a85095728 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
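The two runtime-module hunks above show the mechanical rewrite applied throughout this patch: build the statement against the Buffer object rather than its backing data Var, drop the explicit dtype and predicate arguments, and pass indices as a list. A minimal standalone sketch of the same pattern (the buffer name, shape, and dtype here are illustrative):

import tvm
from tvm import tir

# A 1-D float32 buffer and a loop variable to index it.
Ab = tir.decl_buffer((16,), "float32", name="A")
i = tir.Var("i", "int32")

# Old form (removed by this patch):
#   tir.Store(Ab.data, tir.Load("float32", Ab.data, i) + 1, i + 1)
# New form: the Buffer carries dtype and layout, and the index list
# generalizes naturally to multi-dimensional buffers.
stmt = tir.BufferStore(Ab, tir.BufferLoad(Ab, [i]) + 1.0, [i + 1])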
+import re + import tvm from tvm import te import numpy as np @@ -275,22 +277,16 @@ def test_cuda_shuffle(): def MyVectorize(): def vectorizer(op): if op.kind == tvm.tir.ForKind.VECTORIZED: - four = tvm.tir.const(4, "int32") - idx = tvm.tir.Ramp(thrx.var * four, tvm.tir.const(1, "int32"), 4) - all_ones = tvm.tir.const(1, "int32x4") + idx = tvm.tir.Ramp(4 * thrx.var, 1, 4) store = op.body value = store.value - new_a = tvm.tir.Load("int32x4", value.a.buffer_var, idx, all_ones) + new_a = tvm.tir.BufferLoad(value.a.buffer, [idx]) bs, ids = [], [] for i in range(4): - bs.append( - tvm.tir.Load( - "int32", value.b.buffer_var, thrx.var * four + tvm.tir.const(i, "int32") - ) - ) - ids.append(tvm.tir.const(3 - i, "int32")) + bs.append(tvm.tir.BufferLoad(value.b.buffer, [4 * thrx.var + i])) + ids.append(3 - i) new_b = tvm.tir.Shuffle(bs, ids) - return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones) + return tvm.tir.BufferStore(store.buffer, new_a + new_b, [idx]) return None def _transform(f, *_): @@ -808,23 +804,27 @@ def vcf_check_common(s, args): inside_broadcast = [False] # Possible patterns: - # Reduce init: Store[Ramp] = Broadcast(0) - # Shared memory copy: Store[Ramp] = Load[Ramp] - # Compute: Store[Ramp] = Load[Ramp] ... Broadcast[Load] + # Reduce init: BufferStore[Ramp] = Broadcast(0) + # Shared memory copy: BufferStore[Ramp] = BufferLoad[Ramp] + # Compute: BufferStore[Ramp] = BufferLoad[Ramp] ... Broadcast[BufferLoad] def pre_visit(stmt): if isinstance(stmt, tvm.tir.Broadcast): inside_broadcast[0] = True # Check Broadcast[Imm numbers] or Broadcast[Load] patterns - assert isinstance(stmt.value, (tvm.tir.IntImm, tvm.tir.FloatImm, tvm.tir.Load)) - if isinstance(stmt, tvm.tir.Store): - # Check Store[Ramp] pattern - assert isinstance(stmt.index, tvm.tir.Ramp) - if isinstance(stmt, tvm.tir.Load): - # Check Broadcast[Load] or Load[Ramp] patterns - assert inside_broadcast[0] or isinstance(stmt.index, tvm.tir.Ramp) - # Skip the rest - return stmt + assert isinstance(stmt.value, (tvm.tir.IntImm, tvm.tir.FloatImm, tvm.tir.BufferLoad)) + + if isinstance(stmt, (tvm.tir.BufferStore, tvm.tir.BufferLoad)): + is_ramp_index = isinstance(stmt.indices[-1], tvm.tir.Ramp) + is_vectorized_buffer = re.match(r"^.*x\d+$", stmt.buffer.dtype) + if isinstance(stmt, tvm.tir.BufferLoad): + # Check Broadcast[BufferLoad] or BufferLoad[Ramp] patterns + assert inside_broadcast[0] or is_ramp_index or is_vectorized_buffer + # Skip the rest of the BufferLoad + return stmt + else: + assert is_ramp_index or is_vectorized_buffer + return None def post_visit(stmt): @@ -1037,7 +1037,7 @@ def build(A, C, N, C_N): N, C_N, A, C = get_compute_aligned() a_data, c, kernel_source = build(A, C, N, C_N) # (uint1*)(A + (2)) is a valid vector load - assert "A + (2)" in kernel_source + assert "A + 2" in kernel_source expected = a_data[2 : C_N + 2] assert np.allclose(c, expected), f"expected={expected}\nactual={c}" diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 0e303aaff6eb..45d8b8725c82 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -51,7 +51,7 @@ def test_llvm_void_intrin(): ib = tvm.tir.ir_builder.create() A = ib.pointer("uint8", name="A") # Create an intrinsic that returns void. 
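One detail from the vcf_check_common hunk above is worth a note before the LLVM codegen tests below: is_vectorized_buffer leans on TVM's convention of spelling a vector dtype as the scalar type plus an "xN" lanes suffix. A self-contained check of that regex (the sample dtype strings are illustrative):

import re

# Vector dtypes such as "float32x4" end in an "xN" lanes suffix;
# plain scalar dtypes such as "float32" do not.
is_vectorized = re.compile(r"^.*x\d+$")
assert is_vectorized.match("float32x4")
assert is_vectorized.match("int8x16")
assert not is_vectorized.match("float32")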
- x = tvm.tir.call_llvm_intrin("", "llvm.va_start", tvm.tir.const(1, "uint32"), A) + x = tvm.tir.call_llvm_intrin("", "llvm.va_start", tvm.tir.const(1, "uint32"), A.asobject().data) ib.emit(x) body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main")) @@ -672,13 +672,12 @@ def my_vectorize(): def vectorizer(op): store = op.body idx = tvm.tir.Ramp(tvm.tir.const(0, "int32"), tvm.tir.const(1, "int32"), 8) - all_ones = tvm.tir.const(1, "int32x8") value = store.value b_idx = tvm.tir.Shuffle([idx], [tvm.tir.const(i, "int32") for i in range(7, -1, -1)]) - new_a = tvm.tir.Load("int32x8", value.a.buffer_var, idx, all_ones) - new_b = tvm.tir.Load("int32x8", value.b.buffer_var, b_idx, all_ones) + new_a = tvm.tir.BufferLoad(value.a.buffer, [idx]) + new_b = tvm.tir.BufferLoad(value.b.buffer, [b_idx]) value = new_a + new_b - return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones) + return tvm.tir.BufferStore(store.buffer, new_a + new_b, [idx]) def _transform(f, *_): return f.with_body( @@ -925,7 +924,7 @@ def threadpool_nested_parallel_loop( T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i in T.parallel(4): for j in T.parallel(4): - T.store(B.data, i * 4 + j, T.load("float32", A.data, i * 4 + j) * 2.0) + B[i, j] = A[i, j] * 2.0 with pytest.raises(tvm.TVMError) as e: tvm.build({"llvm": tvm.IRModule.from_expr(threadpool_nested_parallel_loop)}) diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index 7b708cbe0c12..bde1ca4d0a58 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -502,10 +502,9 @@ def do_compute(ins, outs): store_index = index_map[store_type] if indirect_indices: - load_index = tvm.tir.expr.Load("int32x4", R, load_index) + load_index = R[load_index] - transfer = tvm.tir.expr.Load("int32x4", A, load_index) - ib.emit(tvm.tir.stmt.Store(B, transfer, store_index)) + B[store_index] = A[load_index] return ib.get() diff --git a/tests/python/unittest/test_tir_analysis_calculate_workspace.py b/tests/python/unittest/test_tir_analysis_calculate_workspace.py index 4b61625014e2..8449782f4589 100644 --- a/tests/python/unittest/test_tir_analysis_calculate_workspace.py +++ b/tests/python/unittest/test_tir_analysis_calculate_workspace.py @@ -26,29 +26,29 @@ def primfunc_global_allocates(placeholder_144: T.handle, placeholder_145: T.handle, placeholder_146: T.handle, T_cast_48: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "fused_nn_conv2d_add_cast_fixed_point_multiply_clip_cast_cast_13", "tir.noalias": True}) - placeholder_147 = T.match_buffer(placeholder_144, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_148 = T.match_buffer(placeholder_145, [3, 3, 512, 1], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_149 = T.match_buffer(placeholder_146, [1, 1, 1, 512], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_49 = T.match_buffer(T_cast_48, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_147 = T.match_buffer(placeholder_144, [100352], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_148 = T.match_buffer(placeholder_145, [4608], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_149 = T.match_buffer(placeholder_146, [512], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_49 = 
T.match_buffer(T_cast_48, [100352], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_22 = T.allocate([131072], "int16", "global") DepthwiseConv2d_9 = T.allocate([100352], "int32", "global") for i1_29, i2_39, i3_40 in T.grid(16, 16, 512): - PaddedInput_22[(((i1_29*8192) + (i2_39*512)) + i3_40)] = T.if_then_else(((((1 <= i1_29) and (i1_29 < 15)) and (1 <= i2_39)) and (i2_39 < 15)), T.load("int16", placeholder_147.data, ((((i1_29*7168) + (i2_39*512)) + i3_40) - 7680)), T.int16(0), dtype="int16") + PaddedInput_22[(((i1_29*8192) + (i2_39*512)) + i3_40)] = T.if_then_else(((((1 <= i1_29) and (i1_29 < 15)) and (1 <= i2_39)) and (i2_39 < 15)), placeholder_147[((((i1_29*7168) + (i2_39*512)) + i3_40) - 7680)], T.int16(0), dtype="int16") for i_9, j_9, c_9 in T.grid(14, 14, 512): DepthwiseConv2d_9[(((i_9*7168) + (j_9*512)) + c_9)] = 0 for di_9, dj_9 in T.grid(3, 3): - DepthwiseConv2d_9[(((i_9*7168) + (j_9*512)) + c_9)] = (T.load("int32", DepthwiseConv2d_9, (((i_9*7168) + (j_9*512)) + c_9)) + (T.load("int16", PaddedInput_22, (((((i_9*8192) + (di_9*8192)) + (j_9*512)) + (dj_9*512)) + c_9)).astype("int32")*T.load("int16", placeholder_148.data, (((di_9*1536) + (dj_9*512)) + c_9)).astype("int32"))) + DepthwiseConv2d_9[(((i_9*7168) + (j_9*512)) + c_9)] = (DepthwiseConv2d_9[(((i_9*7168) + (j_9*512)) + c_9)] + (PaddedInput_22[(((((i_9*8192) + (di_9*8192)) + (j_9*512)) + (dj_9*512)) + c_9)].astype("int32")*placeholder_148[(((di_9*1536) + (dj_9*512)) + c_9)].astype("int32"))) for ax1_27, ax2_28, ax3_30 in T.grid(14, 14, 512): - DepthwiseConv2d_9[(((ax1_27*7168) + (ax2_28*512)) + ax3_30)] = (T.load("int32", DepthwiseConv2d_9, (((ax1_27*7168) + (ax2_28*512)) + ax3_30)) + T.load("int32", placeholder_149.data, ax3_30)) + DepthwiseConv2d_9[(((ax1_27*7168) + (ax2_28*512)) + ax3_30)] = (DepthwiseConv2d_9[(((ax1_27*7168) + (ax2_28*512)) + ax3_30)] + placeholder_149[ax3_30]) for i1_30, i2_40, i3_41 in T.grid(14, 14, 512): - DepthwiseConv2d_9[(((i1_30*7168) + (i2_40*512)) + i3_41)] = T.q_multiply_shift(T.load("int32", DepthwiseConv2d_9, (((i1_30*7168) + (i2_40*512)) + i3_41)), 1269068532, 31, -4, dtype="int32") + DepthwiseConv2d_9[(((i1_30*7168) + (i2_40*512)) + i3_41)] = T.q_multiply_shift(DepthwiseConv2d_9[(((i1_30*7168) + (i2_40*512)) + i3_41)], 1269068532, 31, -4, dtype="int32") for i1_31, i2_41, i3_42 in T.grid(14, 14, 512): - DepthwiseConv2d_9[(((i1_31*7168) + (i2_41*512)) + i3_42)] = T.max(T.max(T.load("int32", DepthwiseConv2d_9, (((i1_31*7168) + (i2_41*512)) + i3_42)), 255), 0) + DepthwiseConv2d_9[(((i1_31*7168) + (i2_41*512)) + i3_42)] = T.max(T.max(DepthwiseConv2d_9[(((i1_31*7168) + (i2_41*512)) + i3_42)], 255), 0) for ax1_28, ax2_29, ax3_31 in T.grid(14, 14, 512): - PaddedInput_22[(((ax1_28*7168) + (ax2_29*512)) + ax3_31)] = T.load("int32", DepthwiseConv2d_9, (((ax1_28*7168) + (ax2_29*512)) + ax3_31)).astype("uint8") + PaddedInput_22[(((ax1_28*7168) + (ax2_29*512)) + ax3_31)] = DepthwiseConv2d_9[(((ax1_28*7168) + (ax2_29*512)) + ax3_31)].astype("uint8") for ax1_29, ax2_30, ax3_32 in T.grid(14, 14, 512): - T_cast_49.data[(((ax1_29*7168) + (ax2_30*512)) + ax3_32)] = T.load("uint8", PaddedInput_22, (((ax1_29*7168) + (ax2_30*512)) + ax3_32)).astype("int16") + T_cast_49[(((ax1_29*7168) + (ax2_30*512)) + ax3_32)] = PaddedInput_22[(((ax1_29*7168) + (ax2_30*512)) + ax3_32)].astype("int16") # fmt: on @@ -57,36 +57,36 @@ def primfunc_global_allocates(placeholder_144: T.handle, placeholder_145: T.hand def primfunc_local_allocates(placeholder_162: T.handle, placeholder_163: T.handle, 
placeholder_164: T.handle, T_cast_76: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "fused_nn_conv2d_add_cast_fixed_point_multiply_clip_cast_cast_9", "tir.noalias": True}) - placeholder_165 = T.match_buffer(placeholder_162, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_166 = T.match_buffer(placeholder_163, [3, 3, 512, 1], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_167 = T.match_buffer(placeholder_164, [1, 1, 1, 512], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_77 = T.match_buffer(T_cast_76, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_165 = T.match_buffer(placeholder_162, [100352], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_166 = T.match_buffer(placeholder_163, [4608], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_167 = T.match_buffer(placeholder_164, [512], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_77 = T.match_buffer(T_cast_76, [100352], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body - PaddedInput_25 = T.allocate([1, 16, 16, 512], "int16", "global") + PaddedInput_25 = T.allocate([131072], "int16", "global") for i1_35, i2_46, i3_47 in T.grid(16, 16, 512): - PaddedInput_25[(((i1_35*8192) + (i2_46*512)) + i3_47)] = T.if_then_else(((((1 <= i1_35) and (i1_35 < 15)) and (1 <= i2_46)) and (i2_46 < 15)), T.load("int16", placeholder_165.data, ((((i1_35*7168) + (i2_46*512)) + i3_47) - 7680)), T.int16(0), dtype="int16") - T_add_11 = T.allocate([1, 14, 14, 512], "int32", "global") - with T.allocate([1, 14, 14, 512], "int32", "global") as DepthwiseConv2d_11: + PaddedInput_25[(((i1_35*8192) + (i2_46*512)) + i3_47)] = T.if_then_else(((((1 <= i1_35) and (i1_35 < 15)) and (1 <= i2_46)) and (i2_46 < 15)), placeholder_165[((((i1_35*7168) + (i2_46*512)) + i3_47) - 7680)], T.int16(0), dtype="int16") + T_add_11 = T.allocate([100352], "int32", "global") + with T.allocate([100352], "int32", "global") as DepthwiseConv2d_11: for i_11, j_11, c_11 in T.grid(14, 14, 512): DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] = 0 for di_11, dj_11 in T.grid(3, 3): - DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] = (T.load("int32", DepthwiseConv2d_11, (((i_11*7168) + (j_11*512)) + c_11)) + (T.load("int16", PaddedInput_25, (((((i_11*8192) + (di_11*8192)) + (j_11*512)) + (dj_11*512)) + c_11)).astype("int32")*T.load("int16", placeholder_166.data, (((di_11*1536) + (dj_11*512)) + c_11)).astype("int32"))) + DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] = (DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] + (PaddedInput_25[(((((i_11*8192) + (di_11*8192)) + (j_11*512)) + (dj_11*512)) + c_11)].astype("int32")*placeholder_166[(((di_11*1536) + (dj_11*512)) + c_11)].astype("int32"))) for ax1_44, ax2_45, ax3_47 in T.grid(14, 14, 512): - T_add_11[(((ax1_44*7168) + (ax2_45*512)) + ax3_47)] = (T.load("int32", DepthwiseConv2d_11, (((ax1_44*7168) + (ax2_45*512)) + ax3_47)) + T.load("int32", placeholder_167.data, ax3_47)) - compute_22 = T.allocate([1, 14, 14, 512], "int32", "global") - with T.allocate([1, 14, 14, 512], "int32", "global") as T_cast_78: + T_add_11[(((ax1_44*7168) + (ax2_45*512)) + ax3_47)] = (DepthwiseConv2d_11[(((ax1_44*7168) + (ax2_45*512)) + ax3_47)] + placeholder_167[ax3_47]) + compute_22 = T.allocate([100352], "int32", "global") + with T.allocate([100352], "int32", "global") as T_cast_78: for ax1_45, ax2_46, ax3_48 in T.grid(14, 14, 
512):
-            T_cast_78[(((ax1_45*7168) + (ax2_46*512)) + ax3_48)] = T.load("int32", T_add_11, (((ax1_45*7168) + (ax2_46*512)) + ax3_48))
+            T_cast_78[(((ax1_45*7168) + (ax2_46*512)) + ax3_48)] = T_add_11[(((ax1_45*7168) + (ax2_46*512)) + ax3_48)]
         for i1_36, i2_47, i3_48 in T.grid(14, 14, 512):
-            compute_22[(((i1_36*7168) + (i2_47*512)) + i3_48)] = T.q_multiply_shift(T.load("int32", T_cast_78, (((i1_36*7168) + (i2_47*512)) + i3_48)), 1948805937, 31, -5, dtype="int32")
-    T_cast_79 = T.allocate([1, 14, 14, 512], "uint8", "global")
-    with T.allocate([1, 14, 14, 512], "int32", "global") as compute_23:
+            compute_22[(((i1_36*7168) + (i2_47*512)) + i3_48)] = T.q_multiply_shift(T_cast_78[(((i1_36*7168) + (i2_47*512)) + i3_48)], 1948805937, 31, -5, dtype="int32")
+    T_cast_79 = T.allocate([100352], "uint8", "global")
+    with T.allocate([100352], "int32", "global") as compute_23:
         for i1_37, i2_48, i3_49 in T.grid(14, 14, 512):
-            compute_23[(((i1_37*7168) + (i2_48*512)) + i3_49)] = T.max(T.max(T.load("int32", compute_22, (((i1_37*7168) + (i2_48*512)) + i3_49)), 255), 0)
+            compute_23[(((i1_37*7168) + (i2_48*512)) + i3_49)] = T.max(T.max(compute_22[(((i1_37*7168) + (i2_48*512)) + i3_49)], 255), 0)
         for ax1_46, ax2_47, ax3_49 in T.grid(14, 14, 512):
-            T_cast_79[(((ax1_46*7168) + (ax2_47*512)) + ax3_49)] = T.load("int32", compute_23, (((ax1_46*7168) + (ax2_47*512)) + ax3_49)).astype("uint8")
+            T_cast_79[(((ax1_46*7168) + (ax2_47*512)) + ax3_49)] = compute_23[(((ax1_46*7168) + (ax2_47*512)) + ax3_49)].astype("uint8")
     for ax1_47, ax2_48, ax3_50 in T.grid(14, 14, 512):
-        T_cast_77.data[(((ax1_47*7168) + (ax2_48*512)) + ax3_50)] = T.load("uint8", T_cast_79, (((ax1_47*7168) + (ax2_48*512)) + ax3_50)).astype("int16")
+        T_cast_77[(((ax1_47*7168) + (ax2_48*512)) + ax3_50)] = T_cast_79[(((ax1_47*7168) + (ax2_48*512)) + ax3_50)].astype("int16")
     # fmt: on
diff --git a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py
index 1a0dfd09a2df..49121614ffa0 100644
--- a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py
+++ b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py
@@ -54,10 +54,10 @@ def buffer_opaque_access(b: T.handle, c: T.handle) -> None:
         T.writes(B[0:16, 0:16])
         A = T.allocate([256], "float32", "global")
         for i, j in T.grid(16, 16):
-            T.store(A, i * 16 + j, 1)
+            A[i * 16 + j] = 1
         for i in range(0, 16):
             for j in range(0, 16):
-                T.evaluate(T.load("float32", A, i * 16 + j))
+                T.evaluate(A[i * 16 + j])
         for j in range(0, 16):
             T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, T.float32(0), dtype="handle"))
@@ -70,7 +70,7 @@ def buffer_opaque_access(b: T.handle, c: T.handle) -> None:
 @T.prim_func
 def lca_is_func_root(a: T.handle) -> None:
     A = T.match_buffer(a, [0, 0], "float32")
-    A.data[0] = 1.0
+    A[0, 0] = 1.0


 @T.prim_func
diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py
index 422d730160b5..e790ffc199e5 100644
--- a/tests/python/unittest/test_tir_buffer.py
+++ b/tests/python/unittest/test_tir_buffer.py
@@ -82,7 +82,15 @@ def test_buffer_vload():
     n = te.size_var("n")
     Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100)
     load = Ab.vload([2, 3])
-    tvm.testing.assert_prim_expr_equal(load.index, n * 2 + 103)
+    tvm.ir.assert_structural_equal(load.indices, [2, 3])
+
+
+def test_buffer_offset_of():
+    m = te.size_var("m")
+    n = te.size_var("n")
+    Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100)
+    offset = Ab.offset_of([2, 3])
+    tvm.ir.assert_structural_equal(offset, [n * 2 + 103])
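For reference, the behaviour pinned down by the two tests above can be reproduced standalone. This is a minimal sketch, assuming this branch of TVM, where Buffer.vload yields a multi-dimensional BufferLoad and the new Buffer.offset_of exposes the flattened offsets; the names mirror the test and nothing else is assumed:

    import tvm
    from tvm import te

    m = te.size_var("m")
    n = te.size_var("n")
    # A 2-D buffer whose storage begins 100 elements into its allocation.
    Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100)

    # vload keeps the logical coordinates: a BufferLoad with indices [2, 3].
    load = Ab.vload([2, 3])
    assert isinstance(load, tvm.tir.BufferLoad)

    # offset_of recovers the old flattened index: row 2 times the row
    # stride n, plus column 3, plus the elem_offset of 100.
    (offset,) = Ab.offset_of([2, 3])  # structurally equal to n * 2 + 103

In other words, the flattening that vload used to perform eagerly is now a separate, explicitly requested computation.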

 def test_buffer_vload_nullptr():
@@ -124,32 +132,32 @@ def assert_simplified_equal(index_simplified, index_direct):
     idxd = tvm.tir.indexdiv
     idxm = tvm.tir.indexmod
     # Test Case1
-    index_simplified = A_stride.vload(
+    index_simplified = A_stride.offset_of(
         (idxd(idxm(k0, k1), s), idxm(idxm(k0, k1), s) + idxd(k0, k1) * k1)
     )
-    index_direct = A_stride.vload((0, k0))
+    index_direct = A_stride.offset_of((0, k0))
     assert_simplified_equal(index_simplified, index_direct)
     # Test Case2
-    index_simplified = A.vload(
+    index_simplified = A.offset_of(
         (idxd(idxm(k0, idxd(k1, s)), n), idxm(idxm(k0, idxd(k1, s)), n) + idxm(k0, k1))
     )
-    index_direct = A.vload((0, idxm(k0, k1) + idxm(k0, idxd(k1, s))))
+    index_direct = A.offset_of((0, idxm(k0, k1) + idxm(k0, idxd(k1, s))))
     assert_simplified_equal(index_simplified, index_direct)
     # Test Case3
-    index_simplified = A.vload(
+    index_simplified = A.offset_of(
         (
             idxd((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) + idxd(idxm(k0, idxd(k1, s)), n),
             idxm((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) + idxm(idxm(k0, idxd(k1, s)), n),
         )
     )
-    index_direct = A.vload((0, k0))
+    index_direct = A.offset_of((0, k0))
     assert_simplified_equal(index_simplified, index_direct)
     # Test Case4 (not able to simplify)
-    index_simplified = A.vload(
+    index_simplified = A.offset_of(
         (idxd(idxm(k0, idxd(k1, s)), n), idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1))
     )
-    index_direct = A.vload(
+    index_direct = A.offset_of(
         (0, idxd(idxm(k0, idxd(k1, s)), n) * n + (idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1)))
     )
     assert_simplified_equal(index_simplified, index_direct)
@@ -160,7 +168,7 @@ def assert_simplified_equal(index_simplified, index_direct):
     j = te.size_var("j")
     k = te.size_var("k")

-    index_simplified = B.vload(
+    index_simplified = B.offset_of(
         (
             idxd(idxd(idxd((i * 50176 + j * 28672 + k), 1024), 14), 14),
             idxm(idxd(idxd((i * 50176 + j * 28672 + k), 1024), 14), 14),
@@ -168,7 +176,7 @@ def assert_simplified_equal(index_simplified, index_direct):
             idxm((i * 50176 + j * 28672 + k), 1024),
         )
     )
-    index_direct = B.vload((0, 0, 0, (i * 50176 + j * 28672 + k)))
+    index_direct = B.offset_of((0, 0, 0, (i * 50176 + j * 28672 + k)))
     assert_simplified_equal(index_simplified, index_direct)
diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py
index 00aba46ba431..dcd642c3b9ec 100644
--- a/tests/python/unittest/test_tir_constructor.py
+++ b/tests/python/unittest/test_tir_constructor.py
@@ -87,13 +87,14 @@ def test_expr_constructor():
     assert x.false_value == b
     assert x.condition == a

-    buffer_var = te.var("x", dtype="handle")
-    x = tvm.tir.Load("float32", buffer_var, 1, a)
-    assert isinstance(x, tvm.tir.Load)
+    buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32")))
+    buffer = tvm.tir.decl_buffer([16], "float32", data=buffer_var)
+    x = tvm.tir.BufferLoad(buffer, [1])
+    assert isinstance(x, tvm.tir.BufferLoad)
     assert x.dtype == "float32"
-    assert x.buffer_var == buffer_var
-    assert x.index.value == 1
-    assert x.predicate == a
+    assert x.buffer == buffer
+    assert x.buffer.data == buffer_var
+    assert list(x.indices) == [1]

     x = tvm.tir.Ramp(1, 2, 10)
     assert isinstance(x, tvm.tir.Ramp)
@@ -126,7 +127,6 @@ def test_expr_constructor():

 def test_stmt_constructor():
     v = te.var("aa")
-    buffer_var = te.var("buf", dtype="handle")
     nop = tvm.tir.Evaluate(1)
     x = tvm.tir.LetStmt(v, 1, tvm.tir.Evaluate(1))
     assert isinstance(x, tvm.tir.LetStmt)
@@ -148,10 +148,13 @@ assert
x.extent.value == 10 assert x.body == nop - x = tvm.tir.Store(buffer_var, 1, 10, tvm.tir.const(1, "uint1")) - assert isinstance(x, tvm.tir.Store) - assert x.buffer_var == buffer_var - assert x.index.value == 10 + buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("uint1"))) + buffer = tvm.tir.decl_buffer([16], "uint1", data=buffer_var) + x = tvm.tir.BufferStore(buffer, 1, [10]) + assert isinstance(x, tvm.tir.BufferStore) + assert x.buffer == buffer + assert x.buffer.data == buffer_var + assert list(x.indices) == [10] assert x.value.value == 1 buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32"))) diff --git a/tests/python/unittest/test_tir_intrin.py b/tests/python/unittest/test_tir_intrin.py index 3e9e7fd33fd9..b800a6d2109c 100644 --- a/tests/python/unittest/test_tir_intrin.py +++ b/tests/python/unittest/test_tir_intrin.py @@ -236,10 +236,7 @@ def test_tir_fma(A: T.handle, B: T.handle, C: T.handle, d: T.handle) -> None: ) # body for i in T.serial(0, n): - d_1.data[(i * stride_3)] = ( - T.load("float32", A_1.data, (i * stride)) - * T.load("float32", B_1.data, (i * stride_1)) - ) + T.load("float32", C_1.data, (i * stride_2)) + d_1[(i * stride_3)] = (A_1[(i * stride)] * B_1[(i * stride_1)]) + C_1[(i * stride_2)] def test_fma(): diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 5b123e883849..9438da17ede2 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -56,8 +56,8 @@ def test_if(): body = body.body assert isinstance(body, tvm.tir.IfThenElse) assert isinstance(body.condition, tvm.tir.EQ) - assert isinstance(body.then_case.index, tvm.tir.Var) - assert body.else_case.index.value == 0 + assert isinstance(body.then_case.indices[0], tvm.tir.Var) + assert list(body.else_case.indices) == [0] def test_prefetch(): diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py index 5ca9cf0da3c9..93b7caf9cdde 100644 --- a/tests/python/unittest/test_tir_lower_match_buffer.py +++ b/tests/python/unittest/test_tir_lower_match_buffer.py @@ -465,7 +465,7 @@ def fail_match_load(a: T.handle) -> None: T.reads(A[i, j]) T.writes([]) sub_A = T.match_buffer(A[i, j], ()) - T.evaluate(T.load("float32", sub_A.data, 0)) + T.evaluate(sub_A[()]) @T.prim_func @@ -476,7 +476,7 @@ def fail_match_store(a: T.handle) -> None: T.reads([]) T.writes(A[i, j]) sub_A = T.match_buffer(A[i, j], ()) - sub_A.data[0] = 1 + sub_A[()] = 1 @T.prim_func diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index 96224ef6fe55..b4295411bf9b 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -16,7 +16,7 @@ # under the License. 
import pytest import tvm -from tvm import te +from tvm import te, ir import numpy as np @@ -89,11 +89,18 @@ def test_ir(): def test_ir2(): + buf_size = te.var("size") x = te.var("n") - a = te.var("array", "handle") - st = tvm.tir.Store(a, x + 1, 1) - assert isinstance(st, tvm.tir.Store) - assert st.buffer_var == a + + storage_type = ir.PrimType("int32") + handle_type = ir.PointerType(storage_type) + array = te.var("array", handle_type) + buf = tvm.tir.decl_buffer([buf_size], "int32", data=array) + + st = tvm.tir.BufferStore(buf, x + 1, [1]) + assert isinstance(st, tvm.tir.BufferStore) + assert st.buffer == buf + assert st.buffer.data == array def test_let(): diff --git a/tests/python/unittest/test_tir_ptx_mma.py b/tests/python/unittest/test_tir_ptx_mma.py index c304e818ef05..8f653c614d42 100644 --- a/tests/python/unittest/test_tir_ptx_mma.py +++ b/tests/python/unittest/test_tir_ptx_mma.py @@ -52,20 +52,18 @@ def gemm_mma_m8n8k4_row_col_fp64pf64fp64(a: T.handle, b: T.handle, c: T.handle): "fp64", "fp64", "fp64", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="float64", ) ) for mma_accum_c_id in range(2): - C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load( - "float64", Accum, mma_accum_c_id - ) + C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -132,11 +130,11 @@ def gemm_mma_m8n8k4_row_row_fp16fp16fp16(a: T.handle, b: T.handle, c: T.handle): "fp16", "fp16", "fp16", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="float16", @@ -146,7 +144,7 @@ def gemm_mma_m8n8k4_row_row_fp16fp16fp16(a: T.handle, b: T.handle, c: T.handle): C[ ((tx % 32) % 4) + (4 * ((((tx % 32) // 16 + (tx % 32) % 16 // 4 * 2)) % 4)), mma_accum_c_id % 4 + (4 * ((tx % 32) % 16 // 8)) + mma_accum_c_id // 4 * 8, - ] = T.load("float16", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -213,11 +211,11 @@ def gemm_mma_m8n8k4_row_row_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle): "fp16", "fp16", "fp32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="float32", @@ -233,7 +231,7 @@ def gemm_mma_m8n8k4_row_row_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle): + (tx % 32) % 16 // 8 * 4 + mma_accum_c_id % 2 + mma_accum_c_id // 4 * 8, - ] = T.load("float32", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -294,20 +292,18 @@ def gemm_mma_m8n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): "int8", "int8", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", ) ) for mma_accum_c_id in range(2): - C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load( - "int32", Accum, mma_accum_c_id - ) + C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = Accum[mma_accum_c_id] # This test uses mma instructions that are not available on NVCC 10.1. @@ -372,20 +368,18 @@ def gemm_mma_m8n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): "int8", "uint8", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", ) ) for mma_accum_c_id in range(2): - C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load( - "int32", Accum, mma_accum_c_id - ) + C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = Accum[mma_accum_c_id] # This test uses mma instructions that are not available on NVCC 10.1. 
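Every one of these mma tests changes in the same two ways: scalar reads of the accumulator such as T.load("int32", Accum, mma_accum_c_id) become plain subscripts Accum[mma_accum_c_id], while the opaque T.ptx_mma intrinsic now receives the underlying pointer through the buffer's .data field. A minimal sketch of that split, using only the node constructors already exercised by this patch; the variable names are illustrative:

    import tvm

    # A small buffer with an explicit pointer-typed data Var behind it.
    ptr = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32")))
    buf = tvm.tir.decl_buffer([16], "float32", data=ptr)

    # Typed element accesses go through the Buffer object itself; these are
    # the nodes that buf[0] and buf[1] = ... desugar to in TVMScript.
    load = tvm.tir.BufferLoad(buf, [0])
    store = tvm.tir.BufferStore(buf, load + 1.0, [1])

    # Opaque intrinsics and extern calls take the raw handle instead, which
    # is the role MultiA.data / MultiB.data / Accum.data play above.
    handle = buf.data

The upshot is that typed accesses stay analyzable as buffer accesses, and only genuinely opaque consumers see the bare pointer.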
@@ -450,20 +444,18 @@ def gemm_mma_m8n8k32_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle): "int4", "int4", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", ) ) for mma_accum_c_id in range(2): - C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load( - "int32", Accum, mma_accum_c_id - ) + C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = Accum[mma_accum_c_id] # This test uses mma instructions that are not available on NVCC 10.1. @@ -520,20 +512,18 @@ def gemm_mma_m8n8k32_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle): "int4", "uint4", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", ) ) for mma_accum_c_id in range(2): - C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load( - "int32", Accum, mma_accum_c_id - ) + C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = Accum[mma_accum_c_id] # This test uses mma instructions that are not available on NVCC 10.1. @@ -594,20 +584,20 @@ def gemm_mma_m16n8k8_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle) "fp16", "fp16", "fp32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="float32", ) ) for mma_accum_c_id in range(4): - C[ - (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2 - ] = T.load("float32", Accum, mma_accum_c_id) + C[(tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2] = Accum[ + mma_accum_c_id + ] @tvm.testing.requires_cuda @@ -674,11 +664,11 @@ def gemm_mma_m16n8k16_row_col_fp16fp16fp16(a: T.handle, b: T.handle, c: T.handle "fp16", "fp16", "fp16", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="float16", @@ -688,7 +678,7 @@ def gemm_mma_m16n8k16_row_col_fp16fp16fp16(a: T.handle, b: T.handle, c: T.handle C[ (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2, - ] = T.load("float16", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -756,11 +746,11 @@ def gemm_mma_m16n8k16_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle "fp16", "fp16", "fp32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="float32", @@ -770,7 +760,7 @@ def gemm_mma_m16n8k16_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle C[ (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2, - ] = T.load("float32", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -838,11 +828,11 @@ def gemm_mma_m16n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): "int8", "int8", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", @@ -852,7 +842,7 @@ def gemm_mma_m16n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): C[ (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2, - ] = T.load("int32", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -920,11 +910,11 @@ def gemm_mma_m16n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): "int8", "uint8", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", @@ -934,7 +924,7 @@ def gemm_mma_m16n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): C[ (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2, - ] 
= T.load("int32", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -1002,11 +992,11 @@ def gemm_mma_m16n8k32_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): "int8", "int8", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", @@ -1016,7 +1006,7 @@ def gemm_mma_m16n8k32_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): C[ (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2, - ] = T.load("int32", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -1084,11 +1074,11 @@ def gemm_mma_m16n8k32_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): "int8", "uint8", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", @@ -1098,7 +1088,7 @@ def gemm_mma_m16n8k32_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): C[ (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2, - ] = T.load("int32", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -1166,11 +1156,11 @@ def gemm_mma_m16n8k64_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle): "int4", "int4", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", @@ -1180,7 +1170,7 @@ def gemm_mma_m16n8k64_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle): C[ (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2, - ] = T.load("int32", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -1240,11 +1230,11 @@ def gemm_mma_m16n8k64_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle): "int4", "uint4", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", @@ -1254,7 +1244,7 @@ def gemm_mma_m16n8k64_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle): C[ (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2, - ] = T.load("int32", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -1314,11 +1304,11 @@ def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b: T.handle, c: T.handle): "int1", "int1", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", @@ -1328,7 +1318,7 @@ def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b: T.handle, c: T.handle): C[ (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2, - ] = T.load("int32", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index cdbf6cb7f11e..203fb15ef2c2 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -80,7 +80,7 @@ def opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(D[vi, vj]) - D.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj) + D[vi, vj] = A[vi, vj] for i, j in T.grid(8, 8): with T.block("opaque"): vi, vj = T.axis.remap("SS", [i, j]) @@ -288,7 +288,7 @@ def cache_read_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) vi, vj = T.axis.remap("SS", [i, j]) T.reads(A_global[vi, vj]) T.writes(D[vi, vj]) - D.data[vi * 128 + vj] = 
T.load("float16", A_global.data, vi * 128 + vj) + D[vi, vj] = A_global[vi, vj] for i, j in T.grid(8, 8): with T.block("opaque"): vi, vj = T.axis.remap("SS", [i, j]) @@ -518,7 +518,7 @@ def cache_write_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(D_global[vi, vj]) - D_global.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj) + D_global[vi, vj] = A[vi, vj] for i, j in T.grid(8, 8): with T.block("opaque"): vi, vj = T.axis.remap("SS", [i, j]) diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index 5cc36c0df878..f8d767da4645 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -183,7 +183,12 @@ def opaque_access_load(a: T.handle, c: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[0:128, 0:128]) T.writes(C[0:128, 0:128]) - C[vi, vj] = T.load("float32", B.data, vi * 128 + vj) + 1.0 + T.evaluate( + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), B.data, 0, 128, "r", dtype="handle" + ) + ) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -200,8 +205,17 @@ def opaque_access_store(a: T.handle, c: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[0:128, 0:128]) T.writes(C[0:128, 0:128]) - T.store(C.data, vi * 128 + vj, B[vi, vj] + 1.0) - C[vi, vj] = T.load("float32", B.data, vi * 16 + vj) + 1.0 + T.evaluate( + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), B.data, 0, 128, "r", dtype="handle" + ) + ) + T.evaluate( + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), C.data, 0, 128, "w", dtype="handle" + ) + ) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py index fd2d82d1ff1f..f62a316f8013 100644 --- a/tests/python/unittest/test_tir_schedule_reorder.py +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -153,7 +153,7 @@ def opaque_access(a: T.handle, b: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([A[0:16, 0:16]]) - T.store(A.data, vi * 16 + vj, 1) + A[vi, vj] = 1 for i, j in T.grid(16, 16): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) @@ -171,7 +171,7 @@ def opaque_access_reorder(a: T.handle, b: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([A[0:16, 0:16]]) - T.store(A.data, vi * 16 + vj, 1) + A[vi, vj] = 1 for j, i in T.grid(16, 16): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index 2c3b431298e7..b5ab45a505fe 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -273,7 +273,7 @@ def opaque_access(a: T.handle, b: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([A[0:16, 0:16]]) - T.store(A.data, vi * 16 + vj, 1) + A[vi, vj] = 1 for i, j in T.grid(16, 16): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) @@ -292,7 +292,7 @@ def opaque_access_fused(a: T.handle, b: T.handle) -> None: vj = T.axis.S(16, T.floormod(i_j_fused, 16)) T.reads([]) T.writes([A[0:16, 0:16]]) - T.store(A.data, ((vi * 16) + vj), 1, 1) + A[vi, vj] = 1 for i_j_fused in T.serial(0, 256): with T.block("B"): vi = T.axis.S(16, T.floordiv(i_j_fused, 16)) @@ -312,7 +312,7 @@ def opaque_access_split(a: 
T.handle, b: T.handle) -> None:
             vj = T.axis.S(16, j0 * 4 + j1)
             T.reads([])
             T.writes([A[0:16, 0:16]])
-            T.store(A.data, ((vi * 16) + vj), 1, 1)
+            A[vi, vj] = 1
     for i, j0, j1 in T.grid(16, 4, 4):
         with T.block("B"):
             vi = T.axis.S(16, i)
diff --git a/tests/python/unittest/test_tir_transform_combine_context_call.py b/tests/python/unittest/test_tir_transform_combine_context_call.py
index 191aec4b4641..3271e6e2569a 100644
--- a/tests/python/unittest/test_tir_transform_combine_context_call.py
+++ b/tests/python/unittest/test_tir_transform_combine_context_call.py
@@ -29,10 +29,10 @@ def device_context(dev_id):
     n = te.var("n")
     A = ib.allocate("float32", n, name="A", scope="global")
     with ib.for_range(0, n, name="i") as i:
-        ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(0), A))
+        ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(0), A.asobject().data))
         with ib.for_range(0, 10, name="j") as j:
-            ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(1), A))
-            ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(0), A))
+            ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(1), A.asobject().data))
+            ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(0), A.asobject().data))
     body = ib.get()
     mod = tvm.IRModule({"func": tvm.tir.PrimFunc([dev_type, n], body)})
diff --git a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
index b01a9e652f77..17c0cbdd99c6 100644
--- a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
+++ b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
@@ -45,7 +45,7 @@ def test_cse():
                 2,
                 tvm.tir.SeqStmt(
                     [
-                        tvm.tir.Store(buffer.data, z1 + z2, i1),
+                        tvm.tir.BufferStore(buffer, z1 + z2, [i1]),
                         tvm.tir.LetStmt(
                             x,
                             1,
@@ -56,7 +56,7 @@ def test_cse():
                                 a,
                                 (x + y) + (z1 + z2),
                                 tvm.tir.LetStmt(
-                                    b, (x + y) + z3, tvm.tir.Store(buffer.data, a + b, i2)
+                                    b, (x + y) + z3, tvm.tir.BufferStore(buffer, a + b, [i2])
                                 ),
                             ),
                         ),
@@ -96,7 +96,7 @@ def test_cse():

     body = body.body

-    assert isinstance(body[0], tvm.tir.Store)
+    assert isinstance(body[0], tvm.tir.BufferStore)
     assert isinstance(body[1], tvm.tir.LetStmt)

     body = body[1]
@@ -130,7 +130,7 @@ def test_cse():
     # Check that the replacement has been done correctly!
     assert tvm.ir.structural_equal(body.value, cse_var_2 + z3)

-    assert isinstance(body.body, tvm.tir.Store)
+    assert isinstance(body.body, tvm.tir.BufferStore)
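For context on the combine_context_call and CSE hunks above: on this branch, ib.allocate returns a BufferVar whose asobject() is assumed to hand back the allocated tvm.tir.Buffer, so extern calls that previously passed the pointer Var directly now spell the handle argument as asobject().data. A minimal sketch:

    import tvm
    from tvm import te

    ib = tvm.tir.ir_builder.create()
    n = te.var("n")
    A = ib.allocate("float32", n, name="A", scope="global")

    # The extern call wants a raw handle, so pass the data pointer of the
    # Buffer behind the BufferVar rather than the BufferVar itself.
    ib.emit(tvm.tir.call_extern("int32", "fadd", A.asobject().data))
    body = ib.get()

Indexing expressions such as A[i] are unaffected; only call sites that smuggled the allocation out as an opaque argument need the extra .data.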
# First specific test for if nodes : Some duplicated computations appear only in one branch (here the Then branch), not in both branches.
@@ -160,9 +160,9 @@ def test_cse_ifNode_1():
         tvm.tir.IfThenElse(
             b,
             tvm.tir.SeqStmt(
-                [tvm.tir.Store(buffer.data, y + z, i1), tvm.tir.Store(buffer.data, y + z, i2)]
+                [tvm.tir.BufferStore(buffer, y + z, [i1]), tvm.tir.BufferStore(buffer, y + z, [i2])]
             ),
-            tvm.tir.Store(buffer.data, y, i3),
+            tvm.tir.BufferStore(buffer, y, [i3]),
         ),
     )
@@ -217,11 +217,11 @@ def test_cse_ifNode_2():
             b,
             tvm.tir.SeqStmt(
                 [
-                    tvm.tir.Store(buffer.data, y + z, i1),  # (y+z) is present in the Then branch
-                    tvm.tir.Store(buffer.data, y, i2),
+                    tvm.tir.BufferStore(buffer, y + z, [i1]),  # (y+z) is present in the Then branch
+                    tvm.tir.BufferStore(buffer, y, [i2]),
                 ]
             ),
-            tvm.tir.Store(buffer.data, y + z, i3),  # and also present in the Else branch
+            tvm.tir.BufferStore(buffer, y + z, [i3]),  # and also present in the Else branch
         ),
     )
@@ -258,9 +258,9 @@ def test_cse_cascade():
     # Mem[i3] = x+y
     body = tvm.tir.SeqStmt(
         [
-            tvm.tir.Store(buffer.data, (x + y) + z, i1),
-            tvm.tir.Store(buffer.data, (x + y) + z, i2),
-            tvm.tir.Store(buffer.data, (x + y), i3),
+            tvm.tir.BufferStore(buffer, (x + y) + z, [i1]),
+            tvm.tir.BufferStore(buffer, (x + y) + z, [i2]),
+            tvm.tir.BufferStore(buffer, (x + y), [i3]),
         ]
     )
@@ -292,9 +292,9 @@ def test_cse_cascade():
     body = body.body
     assert isinstance(body, tvm.tir.SeqStmt)

-    assert isinstance(body[0], tvm.tir.Store)
-    assert isinstance(body[1], tvm.tir.Store)
-    assert isinstance(body[2], tvm.tir.Store)
+    assert isinstance(body[0], tvm.tir.BufferStore)
+    assert isinstance(body[1], tvm.tir.BufferStore)
+    assert isinstance(body[2], tvm.tir.BufferStore)

     store1 = body[0]
     store2 = body[1]
diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
index 6c42ed7d9280..ee0e7c9605bf 100644
--- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py
+++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
@@ -80,7 +80,8 @@ def unschedulable_func(a: T.handle, c: T.handle) -> None:
         T.writes(C[i, 0:16])
         B = T.alloc_buffer((16, 16), "float32")
         for j in range(0, 16):
-            T.store(B.data, i * 16 + j, A[i, j] + 1.0)
+            T.evaluate(T.call_extern("dummy_extern_function", B.data, dtype="int32"))
+            B[i, j] = A[i, j] + 1.0
         for j in range(0, 16):
             C[i, j] = B[i, j] * 2.0
@@ -251,7 +252,7 @@ def complex_func(a: T.handle, c: T.handle, n: T.int32) -> None:
                 for k in range(4, 8):
                     D[k, j] = 1.0
                 for k in range(2, 4):
-                    T.store(B.data, j, A[i, j] + D[k, j])
+                    B[i, j] = A[i, j] + D[k, j]
             for j in range(3, 5):
                 with T.block() as []:
                     T.reads(B[i, j])
@@ -281,7 +282,7 @@ def compacted_complex_func(a: T.handle, c: T.handle, n: T.int32) -> None:
                 for k in range(4, 8):
                     D[k - 2, 0] = 1.0
                 for k in range(2, 4):
-                    T.store(B.data, j, A[i, j] + D[k - 2, 0])
+                    B[0, j] = A[i, j] + D[k - 2, 0]
             for j in range(3, 5):
                 with T.block() as []:
                     T.reads(B[0, j])
@@ -476,13 +477,15 @@ def opaque_access_annotated_func(a: T.handle) -> None:
             # no annotation, opaque access will cover full region
             T.reads([])
             T.writes([])
-            T.store(B.data, i, "float32", A[i])
+            T.evaluate(T.call_extern("opaque_extern_function", A.data, B.data, dtype="int32"))
+            B[i] = A[i]
         with T.block():
             # treat opaque access only access annotated regions, even if
             # they are not compatible with actual buffer accesses.
T.reads([B[i]]) T.writes([C[i : i + 9]]) - T.store(C.data, i, T.load("float32", B.data, i)) + T.evaluate(T.call_extern("opaque_extern_function", B.data, C.data, dtype="int32")) + C[i] = B[i] @T.prim_func @@ -496,13 +499,15 @@ def compacted_opaque_access_annotated_func(a: T.handle) -> None: # no annotation, opaque access will cover full region T.reads([]) T.writes([]) - T.store(B.data, i, "float32", A[i]) + T.evaluate(T.call_extern("opaque_extern_function", A.data, B.data, dtype="int32")) + B[i] = A[i] with T.block(): # treat opaque access only access annotated regions, even if # they are not compatible with actual buffer accesses. T.reads([B[i]]) T.writes([C[i : i + 9]]) - T.store(C.data, i, T.load("float32", B.data, i)) + T.evaluate(T.call_extern("opaque_extern_function", B.data, C.data, dtype="int32")) + C[i] = B[i] @T.prim_func diff --git a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py index a91fa2591e00..38431705611b 100644 --- a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py +++ b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py @@ -26,22 +26,22 @@ def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: T.handle, placeholder_31: T.handle, placeholder_32: T.handle, T_cast_8: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True}) - placeholder_33 = T.match_buffer(placeholder_30, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_34 = T.match_buffer(placeholder_31, [1, 1, 192, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_35 = T.match_buffer(placeholder_32, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_9 = T.match_buffer(T_cast_8, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_33 = T.match_buffer(placeholder_30, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_34 = T.match_buffer(placeholder_31, [3072], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_35 = T.match_buffer(placeholder_32, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_9 = T.match_buffer(T_cast_8, [12544], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body - PaddedInput_3 = T.allocate([1, 28, 28, 192], "int16", "global") + PaddedInput_3 = T.allocate([150528], "int16", "global") for i0_i1_fused_3 in T.parallel(0, 28): for i2_3, i3_3 in T.grid(28, 192): - T.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), T.load("int16", placeholder_33.data, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)), True) + PaddedInput_3[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3) ] = placeholder_33[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)] for ax0_ax1_fused_ax2_fused_3 in T.parallel(0, 784): for ax3_2 in T.serial(0, 16): - Conv2dOutput_3 = T.allocate([1, 1, 1, 1], "int32", "global") - T.store(Conv2dOutput_3, 0, 0, True) + Conv2dOutput_3 = T.allocate([1], "int32", "global") + Conv2dOutput_3[0] = 0 for rc_3 in T.serial(0, 192): - T.store(Conv2dOutput_3, 0, (T.load("int32", Conv2dOutput_3, 0) + (T.cast(T.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3*192) + rc_3)), "int32")*T.cast(T.load("int16", placeholder_34.data, ((rc_3*16) + ax3_2)), "int32"))), True) - T.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3*16) + ax3_2), 
T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_3, 0) + T.load("int32", placeholder_35.data, ax3_2)), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + Conv2dOutput_3[0] = (Conv2dOutput_3[0] + (T.cast(PaddedInput_3[((ax0_ax1_fused_ax2_fused_3*192) + rc_3)], "int32")*T.cast(placeholder_34[((rc_3*16) + ax3_2)], "int32"))) + T_cast_9[((ax0_ax1_fused_ax2_fused_3*16) + ax3_2)] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_3[0] + placeholder_35[ax3_2]), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16") # fmt: on diff --git a/tests/python/unittest/test_tir_transform_extract_constants.py b/tests/python/unittest/test_tir_transform_extract_constants.py index 74144f252ade..9636a9bdde4c 100644 --- a/tests/python/unittest/test_tir_transform_extract_constants.py +++ b/tests/python/unittest/test_tir_transform_extract_constants.py @@ -28,7 +28,7 @@ def constant1(a: T.handle) -> None: B = T.alloc_buffer((10), "int32") K = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) for x in T.serial(0, 10): - B[x] = A[x] + T.load("int32", K, x) + B[x] = A[x] + K[x] @T.prim_func def constant2(a: T.handle) -> None: @@ -36,7 +36,7 @@ def constant2(a: T.handle) -> None: B = T.alloc_buffer((10), "int32") K = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) for x in T.serial(0, 10): - B[x] = A[x] + T.load("int32", K, x) + B[x] = A[x] + K[x] @T.prim_func def constant3(a: T.handle) -> None: @@ -44,7 +44,7 @@ def constant3(a: T.handle) -> None: B = T.alloc_buffer((10), "int32") K = T.allocate_const([1, 2, 3, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) for x in T.serial(0, 10): - B[x] = A[x] + T.load("int32", K, x) + B[x] = A[x] + K[x] def test_const_extraction(): diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index ca3d4aa70d0b..68b1ad338964 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -50,14 +50,16 @@ def compacted_elementwise_func(a: T.handle, c: T.handle) -> None: @T.prim_func def flattened_elementwise_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") + A = T.match_buffer(a, 256, "float32") + C = T.match_buffer(c, 256, "float32") + T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) + T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) for i in T.serial(0, 16): B_new = T.allocate([16], "float32", "global") for j in T.serial(0, 16): - B_new[j] = T.load("float32", A.data, ((i * 16) + j)) + 1.0 + B_new[j] = A[((i * 16) + j)] + 1.0 for j in T.serial(0, 16): - C.data[((i * 16) + j)] = T.load("float32", B_new, j) * 2.0 + C[((i * 16) + j)] = B_new[j] * 2.0 @T.prim_func @@ -85,8 +87,10 @@ def compacted_gpu_func(a: T.handle, c: T.handle) -> None: @T.prim_func def flattened_gpu_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") + A = T.match_buffer(a, 256, "float32") + C = T.match_buffer(c, 256, "float32") + T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) + T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) i0 = T.env_thread("blockIdx.x") i1 = T.env_thread("threadIdx.x") @@ -97,9 +101,9 @@ def flattened_gpu_func(a: T.handle, c: T.handle) -> None: T.launch_thread(i2, 2) B = T.allocate([16], "float32", "local") for j in range(0, 16): - B[j] = 
T.load("float32", A.data, i0 * 64 + i1 * 32 + i2 * 16 + j) + 1.0 + B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0 for j in range(0, 16): - C.data[i0 * 64 + i1 * 32 + i2 * 16 + j] = T.load("float32", B, j) * 2.0 + C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * 2.0 @T.prim_func @@ -126,15 +130,17 @@ def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> @T.prim_func def flattened_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: - A = T.match_buffer(a, (n, m), "float32") - C = T.match_buffer(c, (n, m), "float32") + A = T.match_buffer(a, n * m, "float32") + C = T.match_buffer(c, n * m, "float32") + T.preflattened_buffer(A, (n, m), "float32", data=A.data) + T.preflattened_buffer(C, (n, m), "float32", data=C.data) for i in range(0, n): B = T.allocate([m], "float32", "global") for j in range(0, m): - B[j] = T.load("float32", A.data, i * m + j) + 1.0 + B[j] = A[i * m + j] + 1.0 for j in range(0, m): - C.data[i * m + j] = T.load("float32", B, j) * 2.0 + C[i * m + j] = B[j] * 2.0 @T.prim_func @@ -154,10 +160,12 @@ def compacted_predicate_func(a: T.handle, c: T.handle) -> None: def flattened_predicate_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (32), "float32") C = T.match_buffer(c, (32), "float32") + T.preflattened_buffer(A, (32), "float32", data=A.data) + T.preflattened_buffer(C, (32), "float32", data=C.data) for i, j in T.grid(5, 7): if i * 7 + j < 32: - C.data[i * 7 + j] = T.load("float32", A.data, i * 7 + j) + 1.0 + C[i * 7 + j] = A[i * 7 + j] + 1.0 @T.prim_func @@ -176,9 +184,11 @@ def compacted_unit_loop_func(a: T.handle, c: T.handle) -> None: def flattened_unit_loop_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (32), "float32") C = T.match_buffer(c, (32), "float32") + T.preflattened_buffer(A, (32), "float32", data=A.data) + T.preflattened_buffer(C, (32), "float32", data=C.data) for x, z in T.grid(4, 8): - C.data[x * 8 + z] = T.load("float32", A.data, x * 8 + z) + 1.0 + C[x * 8 + z] = A[x * 8 + z] + 1.0 @T.prim_func @@ -201,13 +211,15 @@ def compacted_multi_alloc_func(a: T.handle, d: T.handle) -> None: def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None: A = T.match_buffer(a, (32), "float32") D = T.match_buffer(d, (32), "float32") + T.preflattened_buffer(A, (32), "float32", data=A.data) + T.preflattened_buffer(D, (32), "float32", data=D.data) for i in range(0, 32): B = T.allocate((32,), "float32", "global") C = T.allocate((32,), "float32", "global") - B[i] = T.load("float32", A.data, i) + 1.0 - C[i] = T.load("float32", A.data, i) + T.load("float32", B, i) - D.data[i] = T.load("float32", C, i) * 2.0 + B[i] = A[i] + 1.0 + C[i] = A[i] + B[i] + D[i] = C[i] * 2.0 @T.prim_func @@ -235,16 +247,18 @@ def compacted_strided_buffer_func(a: T.handle, c: T.handle) -> None: @T.prim_func def flattened_strided_buffer_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") + A = T.match_buffer(a, (256,), "float32") + C = T.match_buffer(c, (256,), "float32") + T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data) + T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data) for i0 in T.serial(0, 4): B_new = T.allocate([68], "float32", "global") for i1 in T.serial(0, 4): for j in T.serial(0, 16): - B_new[i1 * 17 + j] = T.load("float32", A.data, i0 * 64 + i1 * 16 + j) + 1.0 + B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0 for i1 in T.serial(0, 4): for j in T.serial(0, 16): - C.data[i0 * 64 + i1 * 16 + j] = T.load("float32", 
B_new, i1 * 17 + j) * 2.0 + C[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0 @T.prim_func diff --git a/tests/python/unittest/test_tir_transform_inject_double_buffer.py b/tests/python/unittest/test_tir_transform_inject_double_buffer.py index 9b37bcaaacbc..0f4cc00f0702 100644 --- a/tests/python/unittest/test_tir_transform_inject_double_buffer.py +++ b/tests/python/unittest/test_tir_transform_inject_double_buffer.py @@ -30,7 +30,7 @@ def test_double_buffer(): with ib.for_range(0, n) as i: B = ib.allocate("float32", m, name="B", scope="shared") with ib.new_scope(): - ib.scope_attr(B.asobject(), "double_buffer_scope", 1) + ib.scope_attr(B.asobject().data, "double_buffer_scope", 1) with ib.for_range(0, m) as j: B[j] = A[i * 4 + j] with ib.for_range(0, m) as j: @@ -48,7 +48,7 @@ def test_double_buffer(): stmt = mod["db"].body assert isinstance(stmt.body, tvm.tir.Allocate) - assert stmt.body.extents[0].value == 2 + assert list(stmt.body.extents) == [m * 2] f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] count = [0] diff --git a/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py b/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py index 2298fe94da18..4f70639eada9 100644 --- a/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py +++ b/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py @@ -37,33 +37,41 @@ def _tile_nd(s, tensor, tile): return outer_indices, inner_indices -def _lower_schedule(sch, args): - sch = sch.normalize() - bounds = tvm.te.schedule.InferBound(sch) - stmt = tvm.te.schedule.ScheduleOps(sch, bounds) +@tvm.tir.transform.prim_func_pass(opt_level=0) +def remove_rolling_buffer_attr(func, mod, ctx): + def unwrap(node): + if isinstance(node, tvm.tir.AttrStmt) and node.attr_key == "rolling_buffer_scope": + return node.body + else: + return node + + return func.with_body( + tvm.tir.stmt_functor.ir_transform( + func.body, None, postorder=unwrap, only_enable=["tir.AttrStmt"] + ) + ) - compact = tvm.te.schedule.VerifyCompactBuffer(stmt) - binds, arg_list = get_binds(args, compact, None) - func = tvm.te.schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds) - func = func.with_attr("global_symbol", "main") - func = func.with_attr("tir.noalias", True) - mod = tvm.IRModule({"main": func}) - return mod +@tvm.tir.transform.prim_func_pass(opt_level=0) +def verify_no_rolling_buffer_attr(func, mod, ctx): + def verify(node): + if isinstance(node, tvm.tir.AttrStmt): + assert node.attr_key != "rolling_buffer_scope", "Failed to lower rolling buffers" + tvm.tir.stmt_functor.post_order_visit(func.body, verify) -def _verify_schedule(sch, inputs, output): - mod = _lower_schedule(sch, inputs + [output]) - mods = [] - mods.append(mod) - mod = tvm.tir.transform.InjectRollingBuffer()(mod) + return func - def _check(stmt): - if isinstance(stmt, tvm.tir.AttrStmt): - assert stmt.attr_key != "rolling_buffer_scope", "Failed to lower rolling buffers" - tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _check) - mods.append(mod) +def _verify_schedule(sch, inputs, output): + user_pass_lists = [ + [(0, remove_rolling_buffer_attr), (0, verify_no_rolling_buffer_attr)], + [(0, tvm.tir.transform.InjectRollingBuffer()), (0, verify_no_rolling_buffer_attr)], + ] + built_funcs = [] + for user_pass_list in user_pass_lists: + with tvm.transform.PassContext(config={"tir.add_lower_pass": user_pass_list}): + built_funcs.append(tvm.build(sch, inputs + [output])) outputs = [] ctx = tvm.cpu(0) @@ -75,15 +83,9 @@ def _check(stmt): ) shape = [i.value for i in 
output.shape] out = tvm.nd.array(np.zeros(shape, dtype="int8"), ctx) - for mod in mods: - mod = tvm.tir.transform.StorageFlatten(64)(mod) - mod = tvm.tir.transform.NarrowDataType(32)(mod) - mod = tvm.tir.transform.LoopPartition()(mod) - mod = tvm.tir.transform.StorageRewrite()(mod) - # Build for CPU execution - f = tvm.build(mod) - f(*input_data, out) - outputs.append(out.asnumpy()) + for func in built_funcs: + func(*input_data, out) + outputs.append(out.numpy()) np.testing.assert_equal(outputs[0], outputs[1]) diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py index 673267a9b1fa..1d13acce369a 100644 --- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -17,8 +17,10 @@ import tvm from tvm import te +vthread_name = tvm.testing.parameter("vthread", "cthread") -def test_vthread(): + +def test_vthread(vthread_name): dtype = "int64" n = 100 m = 4 @@ -35,7 +37,7 @@ def get_vthread(name): ib.scope_attr(ty, "virtual_thread", nthread) B = ib.allocate("float32", m, name="B", scope="shared") B[i] = A[i * nthread + tx] - bbuffer = tvm.tir.decl_buffer((m,), dtype=B.dtype, data=B.asobject()) + bbuffer = B.asobject() ib.emit( tvm.tir.call_extern( "int32", @@ -47,20 +49,19 @@ def get_vthread(name): C[i * nthread + tx] = B[i] + 1 return ib.get() - stmt = tvm.tir.transform.InjectVirtualThread()( - tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("vthread"))) - )["main"] - - assert stmt.body.body.extents[0].value == 2 + if vthread_name == "vthread": + B_expected_alloc = m * nthread + elif vthread_name == "cthread": + B_expected_alloc = m * nthread * nthread stmt = tvm.tir.transform.InjectVirtualThread()( - tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("cthread"))) + tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread(vthread_name))) )["main"] - assert len(stmt.body.body.extents) == 3 + assert list(stmt.body.body.extents) == [B_expected_alloc] -def test_vthread_extern(): +def test_vthread_extern(vthread_name): dtype = "int64" n = 100 m = 4 @@ -76,9 +77,9 @@ def get_vthread(name): A = ib.allocate("float32", m, name="A", scope="shared") B = ib.allocate("float32", m, name="B", scope="shared") C = ib.allocate("float32", m, name="C", scope="shared") - cbuffer = tvm.tir.decl_buffer((m,), dtype=C.dtype, data=C.asobject()) - abuffer = tvm.tir.decl_buffer((m,), dtype=A.dtype, data=A.asobject()) - bbuffer = tvm.tir.decl_buffer((m,), dtype=B.dtype, data=B.asobject()) + abuffer = A.asobject() + bbuffer = B.asobject() + cbuffer = C.asobject() A[tx] = tx + 1.0 B[ty] = ty + 1.0 ib.emit( @@ -92,13 +93,19 @@ def get_vthread(name): ) return ib.get() + if vthread_name == "vthread": + A_expected_alloc = m * nthread + elif vthread_name == "cthread": + A_expected_alloc = m * nthread * nthread + + C_expected_alloc = m * nthread * nthread + stmt = tvm.tir.transform.InjectVirtualThread()( - tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("cthread"))) + tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread(vthread_name))) )["main"] - assert stmt.body.body.extents[0].value == 2 - assert stmt.body.body.body.body.extents[0].value == 2 - assert len(stmt.body.body.body.body.extents) == 3 + assert list(stmt.body.body.extents) == [A_expected_alloc] + assert list(stmt.body.body.body.body.extents) == [C_expected_alloc] def test_vthread_if_then_else(): diff --git 
a/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py b/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py index 2c9997f6fe78..9f61b5a3920a 100644 --- a/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py +++ b/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py @@ -111,7 +111,9 @@ def test_in_bounds_vectorize_llvm(): f = tvm.build(s, [A, C], "llvm") dev = tvm.cpu(0) # launch the kernel. - a = tvm.nd.empty((n,), A.dtype).copyfrom(np.random.uniform(size=(n, lanes))) + a = tvm.nd.empty((n,), A.dtype).copyfrom( + np.random.uniform(size=[n] + ([] if lanes == 1 else [lanes])) + ) c = tvm.nd.empty((n,), C.dtype, dev) f(a, c) tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1) @@ -161,7 +163,7 @@ def check_attr_stmt(x): if ( isinstance(x, tvm.tir.AttrStmt) and x.attr_key == "buffer_bound" - and str(x.value) == str(n) + and tvm.ir.structural_equal(x.value.args, [n]) ): return True return False diff --git a/tests/python/unittest/test_tir_transform_ir_utils.py b/tests/python/unittest/test_tir_transform_ir_utils.py index b6752ee3efd3..8030b77f9946 100644 --- a/tests/python/unittest/test_tir_transform_ir_utils.py +++ b/tests/python/unittest/test_tir_transform_ir_utils.py @@ -16,15 +16,18 @@ # under the License. import pytest import tvm -from tvm import tir +from tvm import tir, ir def test_convert_ssa(): + dtype = "int32" zero = tir.const(0) nop = tir.Evaluate(zero) - v = tir.Var("i1", "int32") + var_type = ir.PointerType(ir.PrimType(dtype)) + v = tir.Var("i1", var_type) + buf = tir.decl_buffer([16], dtype=dtype, data=v) for_stmt = tir.For(v, zero, zero, tir.ForKind.SERIAL, nop) - load = tir.Evaluate(tir.Load("int32", v, zero)) + load = tir.Evaluate(tir.BufferLoad(buf, [zero])) seq = tir.SeqStmt([for_stmt, for_stmt, load]) func = tir.PrimFunc([], seq) mod = tvm.IRModule({"main": func}) diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index 8a71169ca78c..6cfe96664d89 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -540,15 +540,17 @@ def test_simple_rfactor(): @T.prim_func -def partitioned_concat(a: T.handle, b: T.handle, c: T.handle) -> None: +def partitioned_concat( + A: T.Buffer[(16,), "float32"], B: T.Buffer[(16,), "float32"], C: T.Buffer[(32,), "float32"] +) -> None: T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - A = T.match_buffer(a, [16], dtype="float32") - B = T.match_buffer(b, [16], dtype="float32") - C = T.match_buffer(c, [32], dtype="float32") + T.preflattened_buffer(A, [16], data=A.data) + T.preflattened_buffer(B, [16], data=B.data) + T.preflattened_buffer(C, [32], data=C.data) for i in T.serial(0, 16): - T.store(C.data, i, T.load("float32", A.data, i), True) + C[i] = A[i] for i in T.serial(0, 16): - T.store(C.data, i + 16, T.load("float32", B.data, i + 16), True) + C[i + 16] = B[i + 16] def test_explicit_partition_hint(): @@ -568,64 +570,42 @@ def test_explicit_partition_hint(): @T.prim_func def partitioned_concat_3( - placeholder: T.Buffer[(1, 64, 28, 28), "int8"], - placeholder_1: T.Buffer[(1, 32, 28, 28), "int8"], - placeholder_2: T.Buffer[(1, 32, 28, 28), "int8"], - T_concat: T.Buffer[(1, 128, 28, 28), "int8"], + placeholder: T.Buffer[(50176,), "int8"], + placeholder_1: T.Buffer[(25088,), "int8"], + placeholder_2: T.Buffer[(25088,), "int8"], + T_concat: T.Buffer[(100352,), "int8"], ) -> None: 
+ T.preflattened_buffer(placeholder, [1, 64, 28, 28], "int8", data=placeholder.data) + T.preflattened_buffer(placeholder_1, [1, 32, 28, 28], "int8", data=placeholder_1.data) + T.preflattened_buffer(placeholder_2, [1, 32, 28, 28], "int8", data=placeholder_2.data) + T.preflattened_buffer(T_concat, [1, 128, 28, 28], "int8", data=T_concat.data) for i1, i2, i3 in T.grid(64, 28, 28): - T.store( - T_concat.data, - i1 * 784 + i2 * 28 + i3, - T.load("int8", placeholder.data, i1 * 784 + i2 * 28 + i3), - True, - ) + T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3] for i1, i2, i3 in T.grid(32, 28, 28): - T.store( - T_concat.data, - i1 * 784 + i2 * 28 + i3 + 50176, - T.load("int8", placeholder_1.data, i1 * 784 + i2 * 28 + i3), - True, - ) + T_concat[i1 * 784 + i2 * 28 + i3 + 50176] = placeholder_1[i1 * 784 + i2 * 28 + i3] for i1, i2, i3 in T.grid(32, 28, 28): - T.store( - T_concat.data, - i1 * 784 + i2 * 28 + i3 + 75264, - T.load("int8", placeholder_2.data, i1 * 784 + i2 * 28 + i3), - True, - ) + T_concat[i1 * 784 + i2 * 28 + i3 + 75264] = placeholder_2[i1 * 784 + i2 * 28 + i3] @T.prim_func def concat_func_3( - placeholder: T.Buffer[(1, 64, 28, 28), "int8"], - placeholder_1: T.Buffer[(1, 32, 28, 28), "int8"], - placeholder_2: T.Buffer[(1, 32, 28, 28), "int8"], - T_concat: T.Buffer[(1, 128, 28, 28), "int8"], + placeholder: T.Buffer[(50176,), "int8"], + placeholder_1: T.Buffer[(25088,), "int8"], + placeholder_2: T.Buffer[(25088,), "int8"], + T_concat: T.Buffer[(100352,), "int8"], ) -> None: + T.preflattened_buffer(placeholder, (1, 64, 28, 28), "int8", data=placeholder.data) + T.preflattened_buffer(placeholder_1, (1, 32, 28, 28), "int8", data=placeholder_1.data) + T.preflattened_buffer(placeholder_2, (1, 32, 28, 28), "int8", data=placeholder_2.data) + T.preflattened_buffer(T_concat, (1, 128, 28, 28), "int8", data=T_concat.data) for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}): for i2, i3 in T.grid(28, 28): if 96 <= i1: - T.store( - T_concat.data, - i1 * 784 + i2 * 28 + i3, - T.load("int8", placeholder_2.data, i1 * 784 + i2 * 28 + i3 - 75264), - True, - ) + T_concat[i1 * 784 + i2 * 28 + i3] = placeholder_2[i1 * 784 + i2 * 28 + i3 - 75264] if 64 <= i1 and i1 < 96: - T.store( - T_concat.data, - i1 * 784 + i2 * 28 + i3, - T.load("int8", placeholder_1.data, i1 * 784 + i2 * 28 + i3 - 50176), - True, - ) + T_concat[i1 * 784 + i2 * 28 + i3] = placeholder_1[i1 * 784 + i2 * 28 + i3 - 50176] if i1 < 64: - T.store( - T_concat.data, - i1 * 784 + i2 * 28 + i3, - T.load("int8", placeholder.data, i1 * 784 + i2 * 28 + i3), - True, - ) + T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3] def test_condition_mutually_exclusive(): diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py index 5b3d7283f14f..2be3bb181150 100644 --- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py +++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py @@ -82,7 +82,7 @@ def lowered_loop_split(a: T.handle, b: T.handle) -> None: T.uint32(1), normal_reduce_temp0[0], True, - reduce_temp0.data, + reduce_temp0[0], ki, dtype="handle", ) @@ -127,7 +127,7 @@ def lowered_no_normal_reduction(a: T.handle, b: T.handle) -> None: ) T.evaluate( T.tvm_thread_allreduce( - T.uint32(1), A[vi, vk], True, reduce_temp0.data, k, dtype="handle" + T.uint32(1), A[vi, vk], True, reduce_temp0[0], k, dtype="handle" ) ) with T.block("B_write_back"): @@ -174,7 
+174,7 @@ def lowered_two_bound_loops(a: T.handle, b: T.handle) -> None: ) T.evaluate( T.tvm_thread_allreduce( - T.uint32(1), A[vi, vk], True, reduce_temp0.data, ko, ki, dtype="handle" + T.uint32(1), A[vi, vk], True, reduce_temp0[0], ko, ki, dtype="handle" ) ) with T.block("B_write_back"): @@ -253,7 +253,7 @@ def lowered_multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> No T.uint32(1), normal_reduce_temp0[0], True, - reduce_temp0.data, + reduce_temp0[0], k0o, dtype="handle", ) @@ -315,7 +315,7 @@ def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None: T.uint32(1), normal_reduce_temp0[0], True, - reduce_temp0.data, + reduce_temp0[0], ki, dtype="handle", ) @@ -418,7 +418,7 @@ def lowered_single_reduction_loop_with_block_predicate( T.uint32(1), in_thread_0[0], True, - cross_thread_0.data, + cross_thread_0[0], ax1_1, dtype="handle", ) @@ -456,7 +456,7 @@ def lowered_single_reduction_loop_with_block_predicate( T.uint32(1), in_thread_1[0], True, - cross_thread_1.data, + cross_thread_1[0], ax1_1, dtype="handle", ) @@ -516,7 +516,7 @@ def lowered_reducer_max(a: T.handle, b: T.handle) -> None: ) T.evaluate( T.tvm_thread_allreduce( - T.uint32(1), A[vi, vk], True, reduce_temp0.data, k, dtype="handle" + T.uint32(1), A[vi, vk], True, reduce_temp0[0], k, dtype="handle" ) ) with T.block("B_write_back"): @@ -556,9 +556,7 @@ def lowered_zero_rank_buffer(a: T.handle, b: T.handle) -> None: T.reinterpret(T.uint64(0), dtype="handle"), ) T.evaluate( - T.tvm_thread_allreduce( - T.uint32(1), A[vk], True, reduce_temp0.data, k, dtype="handle" - ) + T.tvm_thread_allreduce(T.uint32(1), A[vk], True, reduce_temp0[0], k, dtype="handle") ) with T.block("B_write_back"): T.reads([reduce_temp0[0]]) @@ -746,7 +744,7 @@ def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: T.uint32(1), normal_reduce_temp0[0], True, - reduce_temp0.data, + reduce_temp0[0], ax0_1, dtype="handle", ) @@ -789,7 +787,7 @@ def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: T.uint32(1), normal_reduce_temp1[0], True, - reduce_temp1.data, + reduce_temp1[0], ax0_1, dtype="handle", ) diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py index 63772dea65d7..76d6bb82cce3 100644 --- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py @@ -157,8 +157,8 @@ def build_tir(): Aptr[0] = packed_echo(tvm.tir.const(expected_value[0], "float32")) # return handle # let Aptr_var = testing.echo(Aptr) in Aptr_var[1] = expected_value[1] - Aptr_var = ib.let("Aptr_dup", packed_echo(Aptr.asobject())) - ib.emit(tvm.tir.Store(Aptr, tvm.tir.const(expected_value[1], "float32"), 1)) + Aptr_var = ib.let("Aptr_dup", packed_echo(Aptr.asobject().data)) + ib.emit(tvm.tir.BufferStore(Aptr, tvm.tir.const(expected_value[1], "float32"), [1])) stmt = ib.get() return tvm.IRModule.from_expr( diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index 667fad0317db..51c382309856 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -51,9 +51,9 @@ def lower_sch(sch, args, target_bits, extra_passes=None): def test_basic(): def check(m, n, target_bits, target_dtype): ib = tvm.tir.ir_builder.create() - Ab = tvm.tir.decl_buffer((m, n), name="A") + Ab = tvm.tir.decl_buffer([m * n], name="A") A = 
ib.buffer_ptr(Ab) - Bb = tvm.tir.decl_buffer((m, n), name="B") + Bb = tvm.tir.decl_buffer([m * n], name="B") B = ib.buffer_ptr(Bb) with ib.for_range(0, m, name="i") as i: with ib.for_range(0, n, name="j") as j: @@ -83,9 +83,9 @@ def check(m, n, target_bits, target_dtype): def test_thread_axis(): def check(m, n, target_bits, target_dtype): ib = tvm.tir.ir_builder.create() - Ab = tvm.tir.decl_buffer((m, n), name="A") + Ab = tvm.tir.decl_buffer([m * n], name="A") A = ib.buffer_ptr(Ab) - Bb = tvm.tir.decl_buffer((m, n), name="B") + Bb = tvm.tir.decl_buffer([m * n], name="B") B = ib.buffer_ptr(Bb) bx = te.thread_axis("blockIdx.x") tx = te.thread_axis("threadIdx.x") @@ -168,9 +168,9 @@ def test_slice(): def check(m, n, target_bits, target_dtype): # The index may overflow in B, while not in A ib = tvm.tir.ir_builder.create() - Ab = tvm.tir.decl_buffer((m, n), name="A") + Ab = tvm.tir.decl_buffer([m * n], name="A") A = ib.buffer_ptr(Ab) - Bb = tvm.tir.decl_buffer((m, n * 2), name="B") + Bb = tvm.tir.decl_buffer([m * n * 2], name="B") B = ib.buffer_ptr(Bb) with ib.for_range(0, m, name="i") as i: with ib.for_range(0, n, name="j") as j: @@ -242,7 +242,7 @@ def check(shape, index, target_bits, target_dtype): func = mod["main"] z = engine.lower(func, "llvm") stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32) - assert stmt.value.index.dtype == target_dtype + assert stmt.value.indices[0].dtype == target_dtype check( (const(2 ** 16, "int64"), const(2 ** 15 + 1, "int64")), diff --git a/tests/python/unittest/test_tir_transform_remove_no_op.py b/tests/python/unittest/test_tir_transform_remove_no_op.py index 8b7a16952af9..e80d46193507 100644 --- a/tests/python/unittest/test_tir_transform_remove_no_op.py +++ b/tests/python/unittest/test_tir_transform_remove_no_op.py @@ -54,7 +54,7 @@ def test_remove_no_op(): ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body assert isinstance(ret, tvm.tir.Evaluate) - store = tvm.tir.Store(Ab.data, tvm.tir.Load(dtype, Ab.data, i) + 1, i + 1) + store = tvm.tir.BufferStore(Ab, tvm.tir.BufferLoad(Ab, [i]) + 1, [i + 1]) stmt2 = tvm.tir.SeqStmt([nop(), tvm.tir.SeqStmt([store, nop()])]) mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt2)) diff --git a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py index eb3efd317e9c..7f60c95164a8 100644 --- a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py +++ b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py @@ -24,9 +24,12 @@ @tvm.script.ir_module class Before: @T.prim_func - def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: + def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "float32"], conv2d_transpose_nhwc: T.Buffer[(16384,), "float32"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data) + T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data) + T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data) # var definition threadIdx_x = T.env_thread("threadIdx.x") blockIdx_x = T.env_thread("blockIdx.x") @@ -37,24 +40,27 @@ def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 51 weight_shared = T.allocate([4096], 
"float32", "shared") T.launch_thread(threadIdx_x, 32) for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2): - T.store(conv2d_transpose_nhwc_local, i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0), True) + conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0) for i6_0 in T.serial(16): for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): - T.store(PadInput_shared, ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x, T.if_then_else(128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 and blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 < 5, T.load("float32", inputs.data, blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560), T.float32(0), dtype="float32"), True) + PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 and blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 < 5, inputs[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32") for ax0_ax1_ax2_ax3_fused_0 in T.serial(32): - T.store(weight_shared, T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4), T.load("float32x4", weight.data, T.ramp((ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) // 256 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) % 256 // 8 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)), T.broadcast(True, 4)) + weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp((ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) // 256 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) % 256 // 8 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)] for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): - T.store(conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4, T.load("float32", conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4) + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, T.load("float32", PadInput_shared, threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2), T.float32(0), dtype="float32") * T.load("float32", weight_shared, i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024), True) + conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] = conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, PadInput_shared[threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2], T.float32(0), dtype="float32") * weight_shared[i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024] for ax1, ax2 in T.grid(2, 4): - T.store(conv2d_transpose_nhwc.data, threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8, T.load("float32", conv2d_transpose_nhwc_local, ax1 * 4 + ax2), True) + conv2d_transpose_nhwc[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2] 
@tvm.script.ir_module class After: @T.prim_func - def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: + def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "float32"], conv2d_transpose_nhwc: T.Buffer[(16384,), "float32"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data) + T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data) + T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data) # var definition threadIdx_x = T.env_thread("threadIdx.x") blockIdx_x = T.env_thread("blockIdx.x") @@ -65,24 +71,27 @@ def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 51 weight_shared = T.allocate([4096], "float32", "shared") T.launch_thread(threadIdx_x, 32) for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2): - T.store(conv2d_transpose_nhwc_local, i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0), True) + conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0) for i6_0 in T.serial(16): for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): - T.store(PadInput_shared, ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x, T.if_then_else(1 <= (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 4 and (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 20 < 1 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4 and (blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4) // 5 < 1, T.load("float32", inputs.data, blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560), T.float32(0), dtype="float32"), True) + PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(1 <= (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 4 and (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 20 < 1 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4 and (blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4) // 5 < 1, inputs[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32") for ax0_ax1_ax2_ax3_fused_0 in T.serial(32): - T.store(weight_shared, T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4), T.load("float32x4", weight.data, T.ramp((ax0_ax1_ax2_ax3_fused_0 + threadIdx_x * 4 // 128) // 2 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x * 4 // 8) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)), T.broadcast(True, 4)) + weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp((ax0_ax1_ax2_ax3_fused_0 + threadIdx_x * 4 // 128) // 2 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x * 4 // 8) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)] for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): - T.store(conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4, T.load("float32", conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4) + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, T.load("float32", PadInput_shared, threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2), T.float32(0), dtype="float32") * 
T.load("float32", weight_shared, i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024), True) + conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] = conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, PadInput_shared[threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2], T.float32(0), dtype="float32") * weight_shared[i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024] for ax1, ax2 in T.grid(2, 4): - T.store(conv2d_transpose_nhwc.data, threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8, T.load("float32", conv2d_transpose_nhwc_local, ax1 * 4 + ax2), True) + conv2d_transpose_nhwc[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2] @tvm.script.ir_module class After_simplified: @T.prim_func - def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: + def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "float32"], conv2d_transpose_nhwc: T.Buffer[(16384,), "float32"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data) + T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data) + T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data) # var definition threadIdx_x = T.env_thread("threadIdx.x") blockIdx_x = T.env_thread("blockIdx.x") @@ -93,23 +102,23 @@ def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 51 weight_shared = T.allocate([4096], "float32", "shared") T.launch_thread(threadIdx_x, 32) for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2): - T.store(conv2d_transpose_nhwc_local, i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0), True) + conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0) for i6_0 in T.serial(16): for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): - T.store(PadInput_shared, ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x, T.if_then_else(4 <= ax0_ax1_ax2_ax3_fused_0 and ax0_ax1_ax2_ax3_fused_0 < 20 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5, T.load("float32", inputs.data, blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560), T.float32(0), dtype="float32"), True) + PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(4 <= ax0_ax1_ax2_ax3_fused_0 and ax0_ax1_ax2_ax3_fused_0 < 20 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5, inputs[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32") for ax0_ax1_ax2_ax3_fused_0 in T.serial(32): - T.store(weight_shared, T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4), T.load("float32x4", weight.data, T.ramp(ax0_ax1_ax2_ax3_fused_0 // 2 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x // 2) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)), T.broadcast(True, 4)) + 
weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp(ax0_ax1_ax2_ax3_fused_0 // 2 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x // 2) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)] for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): - T.store(conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4, T.load("float32", conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4) + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, T.load("float32", PadInput_shared, threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2), T.float32(0), dtype="float32") * T.load("float32", weight_shared, i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024), True) + conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] = conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, PadInput_shared[threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2], T.float32(0), dtype="float32") * weight_shared[i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024] for ax1, ax2 in T.grid(2, 4): - T.store(conv2d_transpose_nhwc.data, threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8, T.load("float32", conv2d_transpose_nhwc_local, ax1 * 4 + ax2), True) + conv2d_transpose_nhwc[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2] # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,redundant-keyword-arg # fmt: on -def tesd_renormalize_split_pattern(): - after = tvm.tir.transform.RenomalizeSplitPattern()(Before) +def test_renormalize_split_pattern(): + after = tvm.tir.transform.RenormalizeSplitPattern()(Before) tvm.ir.assert_structural_equal(after, After) after = tvm.tir.transform.Simplify()(after) tvm.ir.assert_structural_equal(after, After_simplified) diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index f298288fee9e..824bef4f32f9 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -30,7 +30,7 @@ def test_stmt_simplify(): body = tvm.tir.LetStmt(n, 10, ib.get()) mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C, n], body)) body = tvm.tir.transform.Simplify()(mod)["main"].body - assert isinstance(body.body, tvm.tir.Store) + assert isinstance(body.body, tvm.tir.BufferStore) def test_thread_extent_simplify(): @@ -48,7 +48,7 @@ def test_thread_extent_simplify(): body = tvm.tir.LetStmt(n, 10, ib.get()) mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C, n], body)) body = tvm.tir.transform.Simplify()(mod)["main"].body - assert isinstance(body.body.body.body, tvm.tir.Store) + assert isinstance(body.body.body.body, tvm.tir.BufferStore) def test_if_likely(): diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index a51e926155d3..8e430b035606 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -78,28 +78,25 @@ def test_flatten_storage_align(): def test_flatten_double_buffer(): - dtype = "int64" - n 
= 100 - m = 4 - tx = te.thread_axis("threadIdx.x") - ib = tvm.tir.ir_builder.create() - A = ib.pointer("float32", name="A") - C = ib.pointer("float32", name="C") - ib.scope_attr(tx, "thread_extent", 1) - with ib.for_range(0, n) as i: - B = ib.allocate("float32", m, name="B", scope="shared") - with ib.new_scope(): - ib.scope_attr(B.asobject(), "double_buffer_scope", 1) - with ib.for_range(0, m) as j: - B[j] = A[i * 4 + j] - with ib.for_range(0, m) as j: - C[j] = B[j] + 1 - - stmt = ib.get() - - mod = tvm.IRModule.from_expr( - tvm.tir.PrimFunc([A, C], stmt).with_attr("from_legacy_te_schedule", True) - ) + @tvm.script.ir_module + class ModFromScript: + @T.prim_func + def main(A_param: T.handle, C_param: T.handle): + A = T.match_buffer(A_param, (400,), "float32", strides=[1]) + C = T.match_buffer(C_param, (4,), "float32", strides=[1]) + T.func_attr({"from_legacy_te_schedule": True}) + threadIdx_x = T.env_thread("threadIdx.x") + T.launch_thread(threadIdx_x, 1) + for i in T.serial(0, 100): + B = T.allocate([4], "float32", scope="shared", strides=[1]) + with T.attr(B.data, "double_buffer_scope", 1): + for j in T.serial(0, 4): + B[j] = A[4 * i + j] + + for j in T.serial(0, 4): + C[j] = B[j] + 1.0 + + mod = ModFromScript with tvm.transform.PassContext(config={"tir.InjectDoubleBuffer": {"split_loop": 2}}): mod = tvm.transform.Sequential( @@ -112,10 +109,10 @@ def test_flatten_double_buffer(): stmt = mod["main"].body assert isinstance(stmt.body, tvm.tir.Allocate) - assert stmt.body.extents[0].value == 2 + assert list(stmt.body.extents) == [8] - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C], stmt).with_attr("global_symbol", "db")) - f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] + mod = tvm.tir.transform.ThreadSync("shared")(mod) + f = mod["main"] count = [0] diff --git a/tests/python/unittest/test_tir_transform_unroll_loop.py b/tests/python/unittest/test_tir_transform_unroll_loop.py index b511118f8b52..6dba694e45ac 100644 --- a/tests/python/unittest/test_tir_transform_unroll_loop.py +++ b/tests/python/unittest/test_tir_transform_unroll_loop.py @@ -16,6 +16,7 @@ # under the License. 
import tvm from tvm import te +from tvm.script import tir as T import os @@ -90,7 +91,7 @@ def test_unroll_fake_loop(): } ): ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body - assert isinstance(ret[0], tvm.tir.Store) + assert isinstance(ret[0], tvm.tir.BufferStore) def test_unroll_single_count_loops(): @@ -110,7 +111,31 @@ def test_unroll_single_count_loops(): assert ret == stmt +def test_unroll_allocations(): + @tvm.script.ir_module + class before: + @T.prim_func + def main(): + for i in T.unroll(2): + with T.allocate([16], "float32", "global") as buf: + buf[0] = 0.0 + + @tvm.script.ir_module + class expected: + @T.prim_func + def main(): + with T.allocate([16], "float32", "global") as buf1: + buf1[0] = 0.0 + with T.allocate([16], "float32", "global") as buf2: + buf2[0] = 0.0 + + after = tvm.tir.transform.UnrollLoop()(before) + + tvm.ir.assert_structural_equal(after, expected) + + if __name__ == "__main__": test_unroll_loop() test_unroll_fake_loop() test_unroll_single_count_loops() + test_unroll_allocations() diff --git a/tests/python/unittest/test_tir_transform_vectorize.py b/tests/python/unittest/test_tir_transform_vectorize.py index 1a0d84a4f807..6558de31c00b 100644 --- a/tests/python/unittest/test_tir_transform_vectorize.py +++ b/tests/python/unittest/test_tir_transform_vectorize.py @@ -35,7 +35,8 @@ def test_vectorize_loop(): assert isinstance(stmt, tvm.tir.For) assert not isinstance(stmt.body, tvm.tir.For) - assert isinstance(stmt.body.index, tvm.tir.Ramp) + assert len(stmt.body.indices) == 1 + assert isinstance(stmt.body.indices[0], tvm.tir.Ramp) assert isinstance(stmt.body.value, tvm.tir.Broadcast) @@ -55,7 +56,8 @@ def test_vectorize_vector(): assert isinstance(stmt, tvm.tir.For) assert not isinstance(stmt.body, tvm.tir.For) - assert isinstance(stmt.body.index, tvm.tir.Ramp) + assert len(stmt.body.indices) == 1 + assert isinstance(stmt.body.indices[0], tvm.tir.Ramp) assert isinstance(stmt.body.value, tvm.tir.Broadcast) @@ -76,7 +78,8 @@ def test_vectorize_with_if(): stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body assert isinstance(stmt, tvm.tir.IfThenElse) - assert isinstance(stmt.then_case.index, tvm.tir.Ramp) + assert len(stmt.then_case.indices) == 1 + assert isinstance(stmt.then_case.indices[0], tvm.tir.Ramp) assert isinstance(stmt.then_case.value, tvm.tir.Add) assert stmt.then_case.value.dtype == "float32x4" assert isinstance(stmt.else_case, tvm.tir.For) diff --git a/tests/python/unittest/test_tir_usmp_algo.py b/tests/python/unittest/test_tir_usmp_algo.py index 1995695100cb..548fd96676a0 100644 --- a/tests/python/unittest/test_tir_usmp_algo.py +++ b/tests/python/unittest/test_tir_usmp_algo.py @@ -125,7 +125,7 @@ def test_no_pool_error(): @pytest.mark.parametrize("algorithm", ["greedy_by_size", "greedy_by_conflicts", "hill_climb"]) def test_name_based_ordering(algorithm): - """ This checks when the size and conlicts are same a stable result is generated""" + """This checks that when the size and conflicts are the same, a stable result is generated""" def _test(): target = Target("c") @@ -298,53 +298,53 @@ class MobilenetStructure: def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) - placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, 
offset_factor=1) - T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): - T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True) + T_subtract_1[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)] = (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5[0]) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) - placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): for i2_7, i3_7 in T.grid(229, 3): - T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True) + PaddedInput_7[(((i0_i1_fused_7*687) + (i2_7*3)) + i3_7)] = T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): Conv2dOutput_7 = T.allocate([64], "int32", "global") for ff_3 in T.serial(0, 64): - T.store(Conv2dOutput_7, ff_3, 0, True) + Conv2dOutput_7[ff_3] = 0 for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): - T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), 
"int32"))), True) + Conv2dOutput_7[ff_3] = (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))) for ax3_inner_7 in T.serial(0, 64): - T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + T_cast_21[((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8") @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) - placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [200704], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body tensor_2 = T.allocate([200704], "uint8", "global") for ax0_ax1_fused_4 in T.serial(0, 56): for ax2_4 in T.serial(0, 56): for ax3_init in T.serial(0, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init)] = T.uint8(0) for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)] = T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")) for ax0_ax1_fused_5 in T.serial(0, 56): for ax2_5, ax3_3 in T.grid(56, 64): - T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) + T_cast_7[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)] = T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16") @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: @@ -355,9 +355,9 @@ def run_model(input: T.handle, output: T.handle) -> None: T.attr("default", "device_type", 1) sid_9 = T.allocate([301056], "int8", "global") sid_8 = T.allocate([802816], "int8", "global") - 
T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9.data, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8.data, output, dtype="int32")) __tvm_meta__ = None # fmt: on @@ -418,78 +418,78 @@ class ResnetStructure: def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True}) - placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8") + placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") - T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16") + T_cast_1 = T.match_buffer(T_cast, [360000], dtype="int16") # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): - T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True) + T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True}) - placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16") - placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16") - placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32") - T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16") + placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16") + placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16") + placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32") + T_cast_5 = T.match_buffer(T_cast_4, [360000], dtype="int16") # body PaddedInput_1 = T.allocate([379456], "int16", "global") for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): - T.store(PaddedInput_1, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= 
i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True) + PaddedInput_1[i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1] = T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, placeholder_13[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): Conv2dOutput_1 = T.allocate([64], "int32", "global") for ff_1 in T.serial(0, 64): - T.store(Conv2dOutput_1, ff_1, 0, True) + Conv2dOutput_1[ff_1] = 0 for ry, rx, rc_1 in T.grid(3, 3, 64): - T.store(Conv2dOutput_1, ff_1, T.load("int32", Conv2dOutput_1, ff_1) + T.cast(T.load("int16", PaddedInput_1, T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True) + Conv2dOutput_1[ff_1] = Conv2dOutput_1[ff_1] + T.cast(PaddedInput_1[T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1], "int32") * T.cast(placeholder_14[ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1], "int32") for ax3_inner_2 in T.serial(0, 64): - T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + T_cast_5[ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_1[ax3_inner_2] + placeholder_15[ax3_inner_2], 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", "tir.noalias": True}) - placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16") - placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16") - placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32") - T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32") + placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16") + placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16") + placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32") + T_add_1 = T.match_buffer(T_add, [1440000], dtype="int32") # body PaddedInput_2 = T.allocate([360000], "int16", "global") for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): - T.store(PaddedInput_2, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True) + PaddedInput_2[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] = placeholder_19[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): Conv2dOutput_2 = T.allocate([64], "int32", "global") for ax3_outer_1 in T.serial(0, 4): for ff_2 in T.serial(0, 64): - T.store(Conv2dOutput_2, ff_2, 0, True) + Conv2dOutput_2[ff_2] = 0 for rc_2 in T.serial(0, 64): - T.store(Conv2dOutput_2, 
ff_2, T.load("int32", Conv2dOutput_2, ff_2) + T.cast(T.load("int16", PaddedInput_2, ax0_ax1_fused_ax2_fused_2 * 64 + rc_2), "int32") * T.cast(T.load("int16", placeholder_20.data, rc_2 * 256 + ax3_outer_1 * 64 + ff_2), "int32"), True) + Conv2dOutput_2[ff_2] = Conv2dOutput_2[ff_2] + T.cast(PaddedInput_2[ax0_ax1_fused_ax2_fused_2 * 64 + rc_2], "int32") * T.cast(placeholder_20[rc_2 * 256 + ax3_outer_1 * 64 + ff_2], "int32") for ax3_inner_3 in T.serial(0, 64): - T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True) + T_add_1[ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3] = T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_2[ax3_inner_3] + placeholder_21[ax3_outer_1 * 64 + ax3_inner_3], 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136 @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True}) - placeholder_29 = T.match_buffer(placeholder_22, [1, 75, 75, 64], dtype="int16") - placeholder_27 = T.match_buffer(placeholder_23, [1, 1, 64, 256], dtype="int16") - placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32") - placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32") - T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8") + placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16") + placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16") + placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32") + placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32") + T_cast_7 = T.match_buffer(T_cast_6, [1440000], dtype="uint8") # body PaddedInput_3 = T.allocate([360000], "int16", "global") for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): - T.store(PaddedInput_3, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, T.load("int16", placeholder_29.data, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3), True) + PaddedInput_3[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] = placeholder_29[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): Conv2dOutput_3 = T.allocate([64], "int32", "global") for ax3_outer_2 in T.serial(0, 4): for ff_3 in T.serial(0, 64): - T.store(Conv2dOutput_3, ff_3, 0, True) + Conv2dOutput_3[ff_3] = 0 for rc_3 in T.serial(0, 64): - T.store(Conv2dOutput_3, ff_3, T.load("int32", Conv2dOutput_3, ff_3) + T.cast(T.load("int16", PaddedInput_3, ax0_ax1_fused_ax2_fused_3 * 64 + rc_3), "int32") * T.cast(T.load("int16", placeholder_27.data, rc_3 * 256 + ax3_outer_2 * 64 + ff_3), "int32"), True) + Conv2dOutput_3[ff_3] = Conv2dOutput_3[ff_3] + T.cast(PaddedInput_3[ax0_ax1_fused_ax2_fused_3 * 64 + rc_3], "int32") * T.cast(placeholder_27[rc_3 * 256 + ax3_outer_2 * 64 + ff_3], "int32") for ax3_inner_4 in T.serial(0, 64): - 
T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_3, ax3_inner_4) + T.load("int32", placeholder_26.data, ax3_outer_2 * 64 + ax3_inner_4), 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + T.load("int32", placeholder_28.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4), 255), 0), "uint8"), True) + T_cast_7[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4] = T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_3[ax3_inner_4] + placeholder_26[ax3_outer_2 * 64 + ax3_inner_4], 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + placeholder_28[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4], 255), 0), "uint8") @T.prim_func def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: @@ -502,32 +502,32 @@ def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: sid_6 = T.allocate([5760000], "int8", "global") sid_7 = T.allocate([720000], "int8", "global") sid_8 = T.allocate([720000], "int8", "global") - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6, output, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2.data, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8.data, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7.data, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2.data, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6.data, output, 
dtype="int32")) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True}) - placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16") - placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16") - placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32") - T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16") + placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16") + placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16") + placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32") + T_cast_3 = T.match_buffer(T_cast_2, [360000], dtype="int16") # body PaddedInput = T.allocate([360000], "int16", "global") for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): - T.store(PaddedInput, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True) + PaddedInput[i0_i1_fused * 4800 + i2 * 64 + i3] = placeholder_7[i0_i1_fused * 4800 + i2 * 64 + i3] for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): Conv2dOutput = T.allocate([64], "int32", "global") for ff in T.serial(0, 64): - T.store(Conv2dOutput, ff, 0, True) + Conv2dOutput[ff] = 0 for rc in T.serial(0, 64): - T.store(Conv2dOutput, ff, T.load("int32", Conv2dOutput, ff) + T.cast(T.load("int16", PaddedInput, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True) + Conv2dOutput[ff] = Conv2dOutput[ff] + T.cast(PaddedInput[ax0_ax1_fused_ax2_fused * 64 + rc], "int32") * T.cast(placeholder_8[rc * 64 + ff], "int32") for ax3_inner_1 in T.serial(0, 64): - T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) + T_cast_3[ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput[ax3_inner_1] + placeholder_9[ax3_inner_1], 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16") __tvm_meta__ = None # fmt: on diff --git a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py index ed8ff329ebf4..22b3d5826b3b 100644 --- a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py +++ b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py @@ -99,53 +99,53 @@ class LinearStructure: def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) - placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dTpe="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) - T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_4 = T.match_buffer(placeholder_2, [150528], dTpe="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = 
T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): - T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True) + T_subtract_1[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)] = (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5[0]) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) - placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): for i2_7, i3_7 in T.grid(229, 3): - T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True) + PaddedInput_7[(((i0_i1_fused_7*687) + (i2_7*3)) + i3_7)] = T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): Conv2dOutput_7 = T.allocate([64], "int32", "global") for ff_3 in T.serial(0, 64): - T.store(Conv2dOutput_7, ff_3, 0, True) + Conv2dOutput_7[ff_3] = 0 for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): - T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True) + Conv2dOutput_7[ff_3] = (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66[((((ry_2*1344) + 
(rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))) for ax3_inner_7 in T.serial(0, 64): - T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + T_cast_21[((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8") @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) - placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body tensor_2 = T.allocate([200704], "uint8", "global") for ax0_ax1_fused_4 in T.serial(0, 56): for ax2_4 in T.serial(0, 56): for ax3_init in T.serial(0, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init)] = T.uint8(0) for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)] = T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")) for ax0_ax1_fused_5 in T.serial(0, 56): for ax2_5, ax3_3 in T.grid(56, 64): - T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) + T_cast_7[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)] = T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16") @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: @@ -156,9 +156,9 @@ def run_model(input: T.handle, output: T.handle) -> None: T.attr("default", "device_type", 1) sid_9 = T.allocate([301056], "int8", "global") sid_8 = T.allocate([802816], "int8", "global") - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, 
dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9.data, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8.data, output, dtype="int32")) __tvm_meta__ = None # fmt: on @@ -207,25 +207,25 @@ class ParallelSerialMixedForLoops: def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placeholder_68: T.handle, placeholder_69: T.handle, placeholder_70: T.handle, T_cast_22: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", "tir.noalias": True}) - placeholder_71 = T.match_buffer(placeholder_68, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_72 = T.match_buffer(placeholder_69, [3, 3, 64, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_73 = T.match_buffer(placeholder_70, [1, 1, 1, 192], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_23 = T.match_buffer(T_cast_22, [1, 56, 56, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_71 = T.match_buffer(placeholder_68, [200704], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_72 = T.match_buffer(placeholder_69, [110592], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_73 = T.match_buffer(placeholder_70, [192], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_23 = T.match_buffer(T_cast_22, [305], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_8 = T.allocate([215296], "int16", "global") for i0_i1_fused_8 in T.serial(0, 58): for i2_8, i3_8 in T.grid(58, 64): - T.store(PaddedInput_8, (((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8), T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), T.load("int16", placeholder_71.data, ((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)), T.int16(0), dtype="int16"), True) + PaddedInput_8[(((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8)] = T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), placeholder_71[((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_8 in T.parallel(0, 3136): dummy_allocate = T.allocate([1], "int32", "global") for ax3_outer_4 in T.serial(0, 3): Conv2dOutput_8 = T.allocate([64], "int32", "global") for ff_4 in T.serial(0, 64): - T.store(Conv2dOutput_8, ff_4, 0, True) + Conv2dOutput_8[ff_4] = 0 for ry_3, rx_3, rc_8 in T.grid(3, 3, 64): - T.store(Conv2dOutput_8, ff_4, (T.load("int32", Conv2dOutput_8, ff_4) + (T.cast(T.load("int16", PaddedInput_8, (((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)), "int32")*T.cast(T.load("int16", placeholder_72.data, (((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)), "int32"))), True) + Conv2dOutput_8[ff_4] = (Conv2dOutput_8[ff_4] + (T.cast(PaddedInput_8[(((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 
56)*64)) + rc_8)], "int32")*T.cast(placeholder_72[(((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)], "int32"))) for ax3_inner_8 in T.serial(0, 64): - T.store(T_cast_23.data, (((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_8, ax3_inner_8) + T.load("int32", placeholder_73.data, ((ax3_outer_4*64) + ax3_inner_8))), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8"), True) + T_cast_23[(((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_8[ax3_inner_8] + placeholder_73[((ax3_outer_4*64) + ax3_inner_8)]), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8") @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: @@ -248,25 +248,25 @@ class AllSerialForLoops: def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placeholder_68: T.handle, placeholder_69: T.handle, placeholder_70: T.handle, T_cast_22: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", "tir.noalias": True}) - placeholder_71 = T.match_buffer(placeholder_68, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_72 = T.match_buffer(placeholder_69, [3, 3, 64, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_73 = T.match_buffer(placeholder_70, [1, 1, 1, 192], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_23 = T.match_buffer(T_cast_22, [1, 56, 56, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_71 = T.match_buffer(placeholder_68, [200704], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_72 = T.match_buffer(placeholder_69, [110592], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_73 = T.match_buffer(placeholder_70, [192], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_23 = T.match_buffer(T_cast_22, [305], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_8 = T.allocate([215296], "int16", "global") for i0_i1_fused_8 in T.serial(0, 58): for i2_8, i3_8 in T.grid(58, 64): - T.store(PaddedInput_8, (((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8), T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), T.load("int16", placeholder_71.data, ((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)), T.int16(0), dtype="int16"), True) + PaddedInput_8[(((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8)] = T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), placeholder_71[((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_8 in T.serial(0, 3136): dummy_allocate = T.allocate([1], "int32", "global") for ax3_outer_4 in T.serial(0, 3): Conv2dOutput_8 = T.allocate([64], "int32", "global") for ff_4 in T.serial(0, 64): - T.store(Conv2dOutput_8, ff_4, 0, True) + Conv2dOutput_8[ff_4] = 0 for ry_3, rx_3, rc_8 in T.grid(3, 3, 64): - T.store(Conv2dOutput_8, ff_4, (T.load("int32", Conv2dOutput_8, ff_4) + (T.cast(T.load("int16", PaddedInput_8, (((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)), "int32")*T.cast(T.load("int16", placeholder_72.data, (((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)), "int32"))), 
True) + Conv2dOutput_8[ff_4] = (Conv2dOutput_8[ff_4] + (T.cast(PaddedInput_8[(((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)], "int32")*T.cast(placeholder_72[(((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)], "int32"))) for ax3_inner_8 in T.serial(0, 64): - T.store(T_cast_23.data, (((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_8, ax3_inner_8) + T.load("int32", placeholder_73.data, ((ax3_outer_4*64) + ax3_inner_8))), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8"), True) + T_cast_23[(((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_8[ax3_inner_8] + placeholder_73[((ax3_outer_4*64) + ax3_inner_8)]), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8") @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: @@ -330,284 +330,284 @@ class InceptionStructure: def tvmgen_default_fused_nn_max_pool2d(placeholder: T.handle, tensor: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d", "tir.noalias": True}) - placeholder_1 = T.match_buffer(placeholder, [1, 56, 56, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - tensor_1 = T.match_buffer(tensor, [1, 28, 28, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_1 = T.match_buffer(placeholder, [602112], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + tensor_1 = T.match_buffer(tensor, [249], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body for ax0_ax1_fused in T.serial(0, 28): for ax2 in T.serial(0, 28): for ax3_outer_init, ax3_inner_init in T.grid(3, 64): - T.store(tensor_1.data, ((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer_init*64)) + ax3_inner_init), T.uint8(0), True) + tensor_1[((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer_init*64)) + ax3_inner_init)] = T.uint8(0) for rv0_rv1_fused, ax3_outer, ax3_inner in T.grid(9, 3, 64): - T.store(tensor_1.data, ((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer*64)) + ax3_inner), T.max(T.load("uint8", tensor_1.data, ((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer*64)) + ax3_inner)), T.if_then_else(((((ax0_ax1_fused*2) + T.floordiv(rv0_rv1_fused, 3)) < 56) and (((ax2*2) + T.floormod(rv0_rv1_fused, 3)) < 56)), T.load("uint8", placeholder_1.data, ((((((ax0_ax1_fused*21504) + (T.floordiv(rv0_rv1_fused, 3)*10752)) + (ax2*384)) + (T.floormod(rv0_rv1_fused, 3)*192)) + (ax3_outer*64)) + ax3_inner)), T.uint8(0), dtype="uint8")), True) + tensor_1[((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer*64)) + ax3_inner)] = T.max(tensor_1[((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer*64)) + ax3_inner)], T.if_then_else(((((ax0_ax1_fused*2) + T.floordiv(rv0_rv1_fused, 3)) < 56) and (((ax2*2) + T.floormod(rv0_rv1_fused, 3)) < 56)), placeholder_1[((((((ax0_ax1_fused*21504) + (T.floordiv(rv0_rv1_fused, 3)*10752)) + (ax2*384)) + (T.floormod(rv0_rv1_fused, 3)*192)) + (ax3_outer*64)) + ax3_inner)], T.uint8(0), dtype="uint8")) @T.prim_func def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) - placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 
= T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) - T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): - T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True) + T_subtract_1[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)] = (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5[0]) @T.prim_func def tvmgen_default_fused_cast(placeholder_6: T.handle, T_cast: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast", "tir.noalias": True}) - placeholder_7 = T.match_buffer(placeholder_6, [1, 28, 28, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T_cast_1 = T.match_buffer(T_cast, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_7 = T.match_buffer(placeholder_6, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_1 = T.match_buffer(T_cast, [249], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body for ax0_ax1_fused_2 in T.serial(0, 28): for ax2_2, ax3_outer_1, ax3_inner_2 in T.grid(28, 12, 16): - T.store(T_cast_1.data, ((((ax0_ax1_fused_2*5376) + (ax2_2*192)) + (ax3_outer_1*16)) + ax3_inner_2), T.cast(T.load("uint8", placeholder_7.data, ((((ax0_ax1_fused_2*5376) + (ax2_2*192)) + (ax3_outer_1*16)) + ax3_inner_2)), "int16"), True) + T_cast_1[((((ax0_ax1_fused_2*5376) + (ax2_2*192)) + (ax3_outer_1*16)) + ax3_inner_2)] = T.cast(placeholder_7[((((ax0_ax1_fused_2*5376) + (ax2_2*192)) + (ax3_outer_1*16)) + ax3_inner_2)], "int16") @T.prim_func def tvmgen_default_fused_concatenate(placeholder_8: T.handle, placeholder_9: T.handle, placeholder_10: T.handle, placeholder_11: T.handle, T_concat: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_concatenate", "tir.noalias": True}) - placeholder_12 = T.match_buffer(placeholder_8, [1, 28, 28, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T_concat_1 = T.match_buffer(T_concat, [1, 28, 28, 256], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_13 = T.match_buffer(placeholder_9, [1, 28, 28, 128], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_14 = T.match_buffer(placeholder_11, [1, 28, 28, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_15 = T.match_buffer(placeholder_10, [1, 28, 28, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_12 = T.match_buffer(placeholder_8, [50176], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_concat_1 = T.match_buffer(T_concat, [313], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_13 = T.match_buffer(placeholder_9, [100352], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_14 = T.match_buffer(placeholder_11, [25088], 
dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_15 = T.match_buffer(placeholder_10, [25088], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body for ax0_ax1_fused_3 in T.serial(0, 28): for ax2_3, ax3 in T.grid(28, 256): - T.store(T_concat_1.data, (((ax0_ax1_fused_3*7168) + (ax2_3*256)) + ax3), T.if_then_else((224 <= ax3), T.load("uint8", placeholder_14.data, ((((ax0_ax1_fused_3*896) + (ax2_3*32)) + ax3) - 224)), T.if_then_else((192 <= ax3), T.load("uint8", placeholder_15.data, ((((ax0_ax1_fused_3*896) + (ax2_3*32)) + ax3) - 192)), T.if_then_else((64 <= ax3), T.load("uint8", placeholder_13.data, ((((ax0_ax1_fused_3*3584) + (ax2_3*128)) + ax3) - 64)), T.load("uint8", placeholder_12.data, (((ax0_ax1_fused_3*1792) + (ax2_3*64)) + ax3)), dtype="uint8"), dtype="uint8"), dtype="uint8"), True) + T_concat_1[(((ax0_ax1_fused_3*7168) + (ax2_3*256)) + ax3)] = T.if_then_else((224 <= ax3), placeholder_14[((((ax0_ax1_fused_3*896) + (ax2_3*32)) + ax3) - 224)], T.if_then_else((192 <= ax3), placeholder_15[((((ax0_ax1_fused_3*896) + (ax2_3*32)) + ax3) - 192)], T.if_then_else((64 <= ax3), placeholder_13[((((ax0_ax1_fused_3*3584) + (ax2_3*128)) + ax3) - 64)], placeholder_12[(((ax0_ax1_fused_3*1792) + (ax2_3*64)) + ax3)], dtype="uint8"), dtype="uint8"), dtype="uint8") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_cast_2: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True}) - placeholder_19 = T.match_buffer(placeholder_16, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_3 = T.match_buffer(T_cast_2, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_19 = T.match_buffer(placeholder_16, [200704], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_20 = T.match_buffer(placeholder_17, [4096], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_21 = T.match_buffer(placeholder_18, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_3 = T.match_buffer(T_cast_2, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body PaddedInput = T.allocate([200704], "int16", "global") for i0_i1_fused in T.serial(0, 56): for i2, i3 in T.grid(56, 64): - T.store(PaddedInput, (((i0_i1_fused*3584) + (i2*64)) + i3), T.load("int16", placeholder_19.data, (((i0_i1_fused*3584) + (i2*64)) + i3)), True) + PaddedInput[(((i0_i1_fused*3584) + (i2*64)) + i3)] = placeholder_19[(((i0_i1_fused*3584) + (i2*64)) + i3)] for ax0_ax1_fused_ax2_fused in T.serial(0, 3136): Conv2dOutput = T.allocate([64], "int32", "global") for ff in T.serial(0, 64): - T.store(Conv2dOutput, ff, 0, True) + Conv2dOutput[ff] = 0 for rc in T.serial(0, 64): - T.store(Conv2dOutput, ff, (T.load("int32", Conv2dOutput, ff) + (T.cast(T.load("int16", PaddedInput, ((ax0_ax1_fused_ax2_fused*64) + rc)), "int32")*T.cast(T.load("int16", placeholder_20.data, ((rc*64) + ff)), "int32"))), True) + Conv2dOutput[ff] = (Conv2dOutput[ff] + (T.cast(PaddedInput[((ax0_ax1_fused_ax2_fused*64) + rc)], "int32")*T.cast(placeholder_20[((rc*64) + ff)], 
"int32"))) for ax3_inner_3 in T.serial(0, 64): - T.store(T_cast_3.data, ((ax0_ax1_fused_ax2_fused*64) + ax3_inner_3), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_inner_3)), 1191576922, 31, -4, dtype="int32"), 255), 0), "uint8"), "int16"), True) + T_cast_3[((ax0_ax1_fused_ax2_fused*64) + ax3_inner_3)] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput[ax3_inner_3] + placeholder_21[ax3_inner_3]), 1191576922, 31, -4, dtype="int32"), 255), 0), "uint8"), "int16") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, T_cast_4: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True}) - placeholder_25 = T.match_buffer(placeholder_22, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_26 = T.match_buffer(placeholder_23, [1, 1, 192, 96], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_27 = T.match_buffer(placeholder_24, [1, 1, 1, 96], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_5 = T.match_buffer(T_cast_4, [1, 28, 28, 96], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_25 = T.match_buffer(placeholder_22, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_26 = T.match_buffer(placeholder_23, [18432], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_27 = T.match_buffer(placeholder_24, [96], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_5 = T.match_buffer(T_cast_4, [153], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_1 = T.allocate([150528], "int16", "global") for i0_i1_fused_1 in T.serial(0, 28): for i2_1, i3_1 in T.grid(28, 192): - T.store(PaddedInput_1, (((i0_i1_fused_1*5376) + (i2_1*192)) + i3_1), T.load("int16", placeholder_25.data, (((i0_i1_fused_1*5376) + (i2_1*192)) + i3_1)), True) + PaddedInput_1[(((i0_i1_fused_1*5376) + (i2_1*192)) + i3_1)] = placeholder_25[(((i0_i1_fused_1*5376) + (i2_1*192)) + i3_1)] for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 784): Conv2dOutput_1 = T.allocate([1], "int32", "global") for ax3_1 in T.serial(0, 96): - T.store(Conv2dOutput_1, 0, 0, True) + Conv2dOutput_1[0] = 0 for rc_1 in T.serial(0, 192): - T.store(Conv2dOutput_1, 0, (T.load("int32", Conv2dOutput_1, 0) + (T.cast(T.load("int16", PaddedInput_1, ((ax0_ax1_fused_ax2_fused_1*192) + rc_1)), "int32")*T.cast(T.load("int16", placeholder_26.data, ((rc_1*96) + ax3_1)), "int32"))), True) - T.store(T_cast_5.data, ((ax0_ax1_fused_ax2_fused_1*96) + ax3_1), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_1, 0) + T.load("int32", placeholder_27.data, ax3_1)), 1201322342, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) + Conv2dOutput_1[0] = (Conv2dOutput_1[0] + (T.cast(PaddedInput_1[((ax0_ax1_fused_ax2_fused_1*192) + rc_1)], "int32")*T.cast(placeholder_26[((rc_1*96) + ax3_1)], "int32"))) + T_cast_5[((ax0_ax1_fused_ax2_fused_1*96) + ax3_1)] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_1[0] + placeholder_27[ax3_1]), 1201322342, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16") @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": 
"tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) - placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body tensor_2 = T.allocate([200704], "uint8", "global") for ax0_ax1_fused_4 in T.serial(0, 56): for ax2_4 in T.serial(0, 56): for ax3_init in T.serial(0, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init)] = T.uint8(0) for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)] = T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")) for ax0_ax1_fused_5 in T.serial(0, 56): for ax2_5, ax3_3 in T.grid(56, 64): - T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) + T_cast_7[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)] = T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2(placeholder_30: T.handle, placeholder_31: T.handle, placeholder_32: T.handle, T_cast_8: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2", "tir.noalias": True}) - placeholder_33 = T.match_buffer(placeholder_30, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_34 = T.match_buffer(placeholder_31, [1, 1, 192, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_35 = T.match_buffer(placeholder_32, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_9 = T.match_buffer(T_cast_8, [1, 28, 28, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_33 = T.match_buffer(placeholder_30, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_34 = T.match_buffer(placeholder_31, [12288], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_35 = T.match_buffer(placeholder_32, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_9 = T.match_buffer(T_cast_8, [121], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_2 = T.allocate([150528], "int16", 
"global") for i0_i1_fused_2 in T.serial(0, 28): for i2_2, i3_2 in T.grid(28, 192): - T.store(PaddedInput_2, (((i0_i1_fused_2*5376) + (i2_2*192)) + i3_2), T.load("int16", placeholder_33.data, (((i0_i1_fused_2*5376) + (i2_2*192)) + i3_2)), True) + PaddedInput_2[(((i0_i1_fused_2*5376) + (i2_2*192)) + i3_2)] = placeholder_33[(((i0_i1_fused_2*5376) + (i2_2*192)) + i3_2)] for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 784): Conv2dOutput_2 = T.allocate([64], "int32", "global") for ff_1 in T.serial(0, 64): - T.store(Conv2dOutput_2, ff_1, 0, True) + Conv2dOutput_2[ff_1] = 0 for rc_2 in T.serial(0, 192): - T.store(Conv2dOutput_2, ff_1, (T.load("int32", Conv2dOutput_2, ff_1) + (T.cast(T.load("int16", PaddedInput_2, ((ax0_ax1_fused_ax2_fused_2*192) + rc_2)), "int32")*T.cast(T.load("int16", placeholder_34.data, ((rc_2*64) + ff_1)), "int32"))), True) + Conv2dOutput_2[ff_1] = (Conv2dOutput_2[ff_1] + (T.cast(PaddedInput_2[((ax0_ax1_fused_ax2_fused_2*192) + rc_2)], "int32")*T.cast(placeholder_34[((rc_2*64) + ff_1)], "int32"))) for ax3_inner_4 in T.serial(0, 64): - T.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_2*64) + ax3_inner_4), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_2, ax3_inner_4) + T.load("int32", placeholder_35.data, ax3_inner_4)), 1663316467, 31, -7, dtype="int32"), 255), 0), "uint8"), True) + T_cast_9[((ax0_ax1_fused_ax2_fused_2*64) + ax3_inner_4)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_2[ax3_inner_4] + placeholder_35[ax3_inner_4]), 1663316467, 31, -7, dtype="int32"), 255), 0), "uint8") @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast_1(placeholder_36: T.handle, T_cast_10: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast_1", "tir.noalias": True}) - placeholder_37 = T.match_buffer(placeholder_36, [1, 28, 28, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T_cast_11 = T.match_buffer(T_cast_10, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_37 = T.match_buffer(placeholder_36, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_11 = T.match_buffer(T_cast_10, [249], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body tensor_3 = T.allocate([150528], "uint8", "global") for ax0_ax1_fused_6 in T.serial(0, 28): for ax2_6 in T.serial(0, 28): for ax3_outer_init_1, ax3_inner_init_1 in T.grid(3, 64): - T.store(tensor_3, ((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_init_1*64)) + ax3_inner_init_1), T.uint8(0), True) + tensor_3[((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_init_1*64)) + ax3_inner_init_1)] = T.uint8(0) for rv0_rv1_fused_2, ax3_outer_2, ax3_inner_5 in T.grid(9, 3, 64): - T.store(tensor_3, ((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_2*64)) + ax3_inner_5), T.max(T.load("uint8", tensor_3, ((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_2*64)) + ax3_inner_5)), T.if_then_else(((((1 <= (T.floordiv(rv0_rv1_fused_2, 3) + ax0_ax1_fused_6)) and ((T.floordiv(rv0_rv1_fused_2, 3) + ax0_ax1_fused_6) < 29)) and (1 <= (ax2_6 + T.floormod(rv0_rv1_fused_2, 3)))) and ((ax2_6 + T.floormod(rv0_rv1_fused_2, 3)) < 29)), T.load("uint8", placeholder_37.data, (((((((T.floordiv(rv0_rv1_fused_2, 3)*5376) + (ax0_ax1_fused_6*5376)) + (ax2_6*192)) + (T.floormod(rv0_rv1_fused_2, 3)*192)) + (ax3_outer_2*64)) + ax3_inner_5) - 5568)), T.uint8(0), dtype="uint8")), True) + tensor_3[((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_2*64)) + ax3_inner_5)] = 
T.max(tensor_3[((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_2*64)) + ax3_inner_5)], T.if_then_else(((((1 <= (T.floordiv(rv0_rv1_fused_2, 3) + ax0_ax1_fused_6)) and ((T.floordiv(rv0_rv1_fused_2, 3) + ax0_ax1_fused_6) < 29)) and (1 <= (ax2_6 + T.floormod(rv0_rv1_fused_2, 3)))) and ((ax2_6 + T.floormod(rv0_rv1_fused_2, 3)) < 29)), placeholder_37[(((((((T.floordiv(rv0_rv1_fused_2, 3)*5376) + (ax0_ax1_fused_6*5376)) + (ax2_6*192)) + (T.floormod(rv0_rv1_fused_2, 3)*192)) + (ax3_outer_2*64)) + ax3_inner_5) - 5568)], T.uint8(0), dtype="uint8")) for ax0_ax1_fused_7 in T.serial(0, 28): for ax2_7, ax3_4 in T.grid(28, 192): - T.store(T_cast_11.data, (((ax0_ax1_fused_7*5376) + (ax2_7*192)) + ax3_4), T.cast(T.load("uint8", tensor_3, (((ax0_ax1_fused_7*5376) + (ax2_7*192)) + ax3_4)), "int16"), True) + T_cast_11[(((ax0_ax1_fused_7*5376) + (ax2_7*192)) + ax3_4)] = T.cast(tensor_3[(((ax0_ax1_fused_7*5376) + (ax2_7*192)) + ax3_4)], "int16") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2(placeholder_38: T.handle, placeholder_39: T.handle, placeholder_40: T.handle, T_cast_12: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2", "tir.noalias": True}) - placeholder_41 = T.match_buffer(placeholder_38, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_42 = T.match_buffer(placeholder_39, [1, 1, 192, 32], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_43 = T.match_buffer(placeholder_40, [1, 1, 1, 32], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_13 = T.match_buffer(T_cast_12, [1, 28, 28, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_41 = T.match_buffer(placeholder_38, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_42 = T.match_buffer(placeholder_39, [6144], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_43 = T.match_buffer(placeholder_40, [32], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_13 = T.match_buffer(T_cast_12, [89], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_3 = T.allocate([150528], "int16", "global") for i0_i1_fused_3 in T.serial(0, 28): for i2_3, i3_3 in T.grid(28, 192): - T.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), T.load("int16", placeholder_41.data, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)), True) + PaddedInput_3[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)] = placeholder_41[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)] for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 784): Conv2dOutput_3 = T.allocate([1], "int32", "global") for ax3_5 in T.serial(0, 32): - T.store(Conv2dOutput_3, 0, 0, True) + Conv2dOutput_3[0] = 0 for rc_3 in T.serial(0, 192): - T.store(Conv2dOutput_3, 0, (T.load("int32", Conv2dOutput_3, 0) + (T.cast(T.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3*192) + rc_3)), "int32")*T.cast(T.load("int16", placeholder_42.data, ((rc_3*32) + ax3_5)), "int32"))), True) - T.store(T_cast_13.data, ((ax0_ax1_fused_ax2_fused_3*32) + ax3_5), T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_3, 0) + T.load("int32", placeholder_43.data, ax3_5)), 1811141736, 31, -6, dtype="int32"), 255), 0), "uint8"), "int32"), 1136333842, 31, 0, dtype="int32"), 
255), 0), "uint8"), True) + Conv2dOutput_3[0] = (Conv2dOutput_3[0] + (T.cast(PaddedInput_3[((ax0_ax1_fused_ax2_fused_3*192) + rc_3)], "int32")*T.cast(placeholder_42[((rc_3*32) + ax3_5)], "int32"))) + T_cast_13[((ax0_ax1_fused_ax2_fused_3*32) + ax3_5)] = T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_3[0] + placeholder_43[ax3_5]), 1811141736, 31, -6, dtype="int32"), 255), 0), "uint8"), "int32"), 1136333842, 31, 0, dtype="int32"), 255), 0), "uint8") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_44: T.handle, placeholder_45: T.handle, placeholder_46: T.handle, T_cast_14: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True}) - placeholder_47 = T.match_buffer(placeholder_44, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_48 = T.match_buffer(placeholder_45, [1, 1, 192, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_49 = T.match_buffer(placeholder_46, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_15 = T.match_buffer(T_cast_14, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_47 = T.match_buffer(placeholder_44, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_48 = T.match_buffer(placeholder_45, [3072], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_49 = T.match_buffer(placeholder_46, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_15 = T.match_buffer(T_cast_14, [73], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_4 = T.allocate([150528], "int16", "global") for i0_i1_fused_4 in T.serial(0, 28): for i2_4, i3_4 in T.grid(28, 192): - T.store(PaddedInput_4, (((i0_i1_fused_4*5376) + (i2_4*192)) + i3_4), T.load("int16", placeholder_47.data, (((i0_i1_fused_4*5376) + (i2_4*192)) + i3_4)), True) + PaddedInput_4[(((i0_i1_fused_4*5376) + (i2_4*192)) + i3_4)] = placeholder_47[(((i0_i1_fused_4*5376) + (i2_4*192)) + i3_4)] for ax0_ax1_fused_ax2_fused_4 in T.serial(0, 784): Conv2dOutput_4 = T.allocate([1], "int32", "global") for ax3_6 in T.serial(0, 16): - T.store(Conv2dOutput_4, 0, 0, True) + Conv2dOutput_4[0] = 0 for rc_4 in T.serial(0, 192): - T.store(Conv2dOutput_4, 0, (T.load("int32", Conv2dOutput_4, 0) + (T.cast(T.load("int16", PaddedInput_4, ((ax0_ax1_fused_ax2_fused_4*192) + rc_4)), "int32")*T.cast(T.load("int16", placeholder_48.data, ((rc_4*16) + ax3_6)), "int32"))), True) - T.store(T_cast_15.data, ((ax0_ax1_fused_ax2_fused_4*16) + ax3_6), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_4, 0) + T.load("int32", placeholder_49.data, ax3_6)), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + Conv2dOutput_4[0] = (Conv2dOutput_4[0] + (T.cast(PaddedInput_4[((ax0_ax1_fused_ax2_fused_4*192) + rc_4)], "int32")*T.cast(placeholder_48[((rc_4*16) + ax3_6)], "int32"))) + T_cast_15[((ax0_ax1_fused_ax2_fused_4*16) + ax3_6)] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_4[0] + placeholder_49[ax3_6]), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__1(placeholder_50: T.handle, placeholder_51: T.handle, placeholder_52: T.handle, 
T_cast_16: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__1", "tir.noalias": True}) - placeholder_53 = T.match_buffer(placeholder_50, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_54 = T.match_buffer(placeholder_51, [3, 3, 16, 32], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_55 = T.match_buffer(placeholder_52, [1, 1, 1, 32], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_17 = T.match_buffer(T_cast_16, [1, 28, 28, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_53 = T.match_buffer(placeholder_50, [12544], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_54 = T.match_buffer(placeholder_51, [4608], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_55 = T.match_buffer(placeholder_52, [32], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_17 = T.match_buffer(T_cast_16, [89], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_5 = T.allocate([14400], "int16", "global") for i0_i1_fused_5 in T.serial(0, 30): for i2_5, i3_5 in T.grid(30, 16): - T.store(PaddedInput_5, (((i0_i1_fused_5*480) + (i2_5*16)) + i3_5), T.if_then_else(((((1 <= i0_i1_fused_5) and (i0_i1_fused_5 < 29)) and (1 <= i2_5)) and (i2_5 < 29)), T.load("int16", placeholder_53.data, ((((i0_i1_fused_5*448) + (i2_5*16)) + i3_5) - 464)), T.int16(0), dtype="int16"), True) + PaddedInput_5[(((i0_i1_fused_5*480) + (i2_5*16)) + i3_5)] = T.if_then_else(((((1 <= i0_i1_fused_5) and (i0_i1_fused_5 < 29)) and (1 <= i2_5)) and (i2_5 < 29)), placeholder_53[((((i0_i1_fused_5*448) + (i2_5*16)) + i3_5) - 464)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_5 in T.serial(0, 784): Conv2dOutput_5 = T.allocate([1], "int32", "global") for ax3_7 in T.serial(0, 32): - T.store(Conv2dOutput_5, 0, 0, True) + Conv2dOutput_5[0] = 0 for ry, rx, rc_5 in T.grid(3, 3, 16): - T.store(Conv2dOutput_5, 0, (T.load("int32", Conv2dOutput_5, 0) + (T.cast(T.load("int16", PaddedInput_5, (((((T.floordiv(ax0_ax1_fused_ax2_fused_5, 28)*480) + (ry*480)) + (rx*16)) + (T.floormod(ax0_ax1_fused_ax2_fused_5, 28)*16)) + rc_5)), "int32")*T.cast(T.load("int16", placeholder_54.data, ((((ry*1536) + (rx*512)) + (rc_5*32)) + ax3_7)), "int32"))), True) - T.store(T_cast_17.data, ((ax0_ax1_fused_ax2_fused_5*32) + ax3_7), T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_5, 0) + T.load("int32", placeholder_55.data, ax3_7)), 1131968888, 31, -6, dtype="int32"), 255), 0), "uint8"), "int32"), 1900719667, 31, 0, dtype="int32"), 255), 0), "uint8"), True) + Conv2dOutput_5[0] = (Conv2dOutput_5[0] + (T.cast(PaddedInput_5[(((((T.floordiv(ax0_ax1_fused_ax2_fused_5, 28)*480) + (ry*480)) + (rx*16)) + (T.floormod(ax0_ax1_fused_ax2_fused_5, 28)*16)) + rc_5)], "int32")*T.cast(placeholder_54[((((ry*1536) + (rx*512)) + (rc_5*32)) + ax3_7)], "int32"))) + T_cast_17[((ax0_ax1_fused_ax2_fused_5*32) + ax3_7)] = T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_5[0] + placeholder_55[ax3_7]), 1131968888, 31, -6, dtype="int32"), 255), 0), "uint8"), "int32"), 1900719667, 31, 0, dtype="int32"), 255), 0), "uint8") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320_(placeholder_56: 
T.handle, placeholder_57: T.handle, placeholder_58: T.handle, T_cast_18: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320_", "tir.noalias": True}) - placeholder_59 = T.match_buffer(placeholder_56, [1, 28, 28, 96], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_60 = T.match_buffer(placeholder_57, [3, 3, 96, 128], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_61 = T.match_buffer(placeholder_58, [1, 1, 1, 128], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_19 = T.match_buffer(T_cast_18, [1, 28, 28, 128], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_59 = T.match_buffer(placeholder_56, [75264], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_60 = T.match_buffer(placeholder_57, [110592], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_61 = T.match_buffer(placeholder_58, [128], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_19 = T.match_buffer(T_cast_18, [185], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_6 = T.allocate([86400], "int16", "global") for i0_i1_fused_6 in T.serial(0, 30): for i2_6, i3_6 in T.grid(30, 96): - T.store(PaddedInput_6, (((i0_i1_fused_6*2880) + (i2_6*96)) + i3_6), T.if_then_else(((((1 <= i0_i1_fused_6) and (i0_i1_fused_6 < 29)) and (1 <= i2_6)) and (i2_6 < 29)), T.load("int16", placeholder_59.data, ((((i0_i1_fused_6*2688) + (i2_6*96)) + i3_6) - 2784)), T.int16(0), dtype="int16"), True) + PaddedInput_6[(((i0_i1_fused_6*2880) + (i2_6*96)) + i3_6)] = T.if_then_else(((((1 <= i0_i1_fused_6) and (i0_i1_fused_6 < 29)) and (1 <= i2_6)) and (i2_6 < 29)), placeholder_59[((((i0_i1_fused_6*2688) + (i2_6*96)) + i3_6) - 2784)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_6 in T.serial(0, 784): Conv2dOutput_6 = T.allocate([64], "int32", "global") for ax3_outer_3 in T.serial(0, 2): for ff_2 in T.serial(0, 64): - T.store(Conv2dOutput_6, ff_2, 0, True) + Conv2dOutput_6[ff_2] = 0 for ry_1, rx_1, rc_6 in T.grid(3, 3, 96): - T.store(Conv2dOutput_6, ff_2, (T.load("int32", Conv2dOutput_6, ff_2) + (T.cast(T.load("int16", PaddedInput_6, (((((T.floordiv(ax0_ax1_fused_ax2_fused_6, 28)*2880) + (ry_1*2880)) + (rx_1*96)) + (T.floormod(ax0_ax1_fused_ax2_fused_6, 28)*96)) + rc_6)), "int32")*T.cast(T.load("int16", placeholder_60.data, (((((ry_1*36864) + (rx_1*12288)) + (rc_6*128)) + (ax3_outer_3*64)) + ff_2)), "int32"))), True) + Conv2dOutput_6[ff_2] = (Conv2dOutput_6[ff_2] + (T.cast(PaddedInput_6[(((((T.floordiv(ax0_ax1_fused_ax2_fused_6, 28)*2880) + (ry_1*2880)) + (rx_1*96)) + (T.floormod(ax0_ax1_fused_ax2_fused_6, 28)*96)) + rc_6)], "int32")*T.cast(placeholder_60[(((((ry_1*36864) + (rx_1*12288)) + (rc_6*128)) + (ax3_outer_3*64)) + ff_2)], "int32"))) for ax3_inner_6 in T.serial(0, 64): - T.store(T_cast_19.data, (((ax0_ax1_fused_ax2_fused_6*128) + (ax3_outer_3*64)) + ax3_inner_6), T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_6, ax3_inner_6) + T.load("int32", placeholder_61.data, ((ax3_outer_3*64) + ax3_inner_6))), 1374050734, 31, -7, dtype="int32"), 255), 0), "uint8"), "int32"), 1544713713, 31, 0, dtype="int32"), 255), 0), "uint8"), True) + T_cast_19[(((ax0_ax1_fused_ax2_fused_6*128) + (ax3_outer_3*64)) + ax3_inner_6)] = 
T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_6[ax3_inner_6] + placeholder_61[((ax3_outer_3*64) + ax3_inner_6)]), 1374050734, 31, -7, dtype="int32"), 255), 0), "uint8"), "int32"), 1544713713, 31, 0, dtype="int32"), 255), 0), "uint8") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) - placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): for i2_7, i3_7 in T.grid(229, 3): - T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True) + PaddedInput_7[(((i0_i1_fused_7*687) + (i2_7*3)) + i3_7)] = T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): Conv2dOutput_7 = T.allocate([64], "int32", "global") for ff_3 in T.serial(0, 64): - T.store(Conv2dOutput_7, ff_3, 0, True) + Conv2dOutput_7[ff_3] = 0 for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): - T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True) + Conv2dOutput_7[ff_3] = (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))) for ax3_inner_7 in T.serial(0, 64): - T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + T_cast_21[((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7)] =
T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8")
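# NOTE: a minimal sketch of the rewrite pattern applied throughout these
# fixtures, shown here between two converted prim_funcs. T.load/T.store
# intrinsics become buffer subscript reads/writes, and N-D match_buffer
# shapes become flattened 1-D extents. The names `copy_cast`, `a`, and `b`
# are hypothetical and appear nowhere in this patch.
#
# from tvm.script import tir as T
#
# @T.prim_func
# def copy_cast(a: T.handle, b: T.handle) -> None:
#     A = T.match_buffer(a, [64], dtype="uint8")  # was e.g. [1, 1, 1, 64]
#     B = T.match_buffer(b, [64], dtype="int16")
#     for i in T.serial(0, 64):
#         # was: T.store(B.data, i, T.cast(T.load("uint8", A.data, i), "int16"), True)
#         B[i] = T.cast(A[i], "int16")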
@T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placeholder_68: T.handle, placeholder_69: T.handle, placeholder_70: T.handle, T_cast_22: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", "tir.noalias": True}) - placeholder_71 = T.match_buffer(placeholder_68, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_72 = T.match_buffer(placeholder_69, [3, 3, 64, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_73 = T.match_buffer(placeholder_70, [1, 1, 1, 192], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_23 = T.match_buffer(T_cast_22, [1, 56, 56, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_71 = T.match_buffer(placeholder_68, [200704], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_72 = T.match_buffer(placeholder_69, [110592], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_73 = T.match_buffer(placeholder_70, [192], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_23 = T.match_buffer(T_cast_22, [305], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_8 = T.allocate([215296], "int16", "global") for i0_i1_fused_8 in T.serial(0, 58): for i2_8, i3_8 in T.grid(58, 64): - T.store(PaddedInput_8, (((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8), T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), T.load("int16", placeholder_71.data, ((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)), T.int16(0), dtype="int16"), True) + PaddedInput_8[(((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8)] = T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), placeholder_71[((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_8 in T.serial(0, 3136): Conv2dOutput_8 = T.allocate([64], "int32", "global") for ax3_outer_4 in T.serial(0, 3): for ff_4 in T.serial(0, 64): - T.store(Conv2dOutput_8, ff_4, 0, True) + Conv2dOutput_8[ff_4] = 0 for ry_3, rx_3, rc_8 in T.grid(3, 3, 64): - T.store(Conv2dOutput_8, ff_4, (T.load("int32", Conv2dOutput_8, ff_4) + (T.cast(T.load("int16", PaddedInput_8, (((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)), "int32")*T.cast(T.load("int16", placeholder_72.data, (((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)), "int32"))), True) + Conv2dOutput_8[ff_4] = (Conv2dOutput_8[ff_4] + (T.cast(PaddedInput_8[(((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)], "int32")*T.cast(placeholder_72[(((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)], "int32"))) for ax3_inner_8 in T.serial(0, 64): - T.store(T_cast_23.data, (((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_8, ax3_inner_8) + T.load("int32", placeholder_73.data, ((ax3_outer_4*64) + ax3_inner_8))), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8"), True) + T_cast_23[(((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_8[ax3_inner_8] + placeholder_73[((ax3_outer_4*64) + ax3_inner_8)]), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8") @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: @@ -630,21 +630,21 @@ def run_model(input: T.handle, output: T.handle) -> None: sid_25 = T.allocate([25088], "int8", "global") sid_26 = T.allocate([25088], "int8", "global") sid_31 = T.allocate([25088], "int8", "global") - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, sid_7, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_7, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_6, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", sid_6, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_5, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d", sid_5, sid_4, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_cast", sid_4, sid_3, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2", sid_3, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_2, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_3, T.lookup_param("p9", dtype="handle"), T.lookup_param("p10", dtype="handle"), sid_20, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320_", sid_20, T.lookup_param("p11", dtype="handle"), T.lookup_param("p12", dtype="handle"), sid_19, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", sid_3, T.lookup_param("p13", dtype="handle"), T.lookup_param("p14", dtype="handle"), sid_26, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__1", sid_26, T.lookup_param("p15", dtype="handle"), T.lookup_param("p16", dtype="handle"), sid_25, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast_1", sid_4, sid_32, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2", sid_32, T.lookup_param("p17", dtype="handle"), T.lookup_param("p18", dtype="handle"), sid_31, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_concatenate", sid_2, sid_19, sid_25, sid_31, output, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9.data, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8.data, dtype="int32")) +
T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8.data, sid_7.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_7.data, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_6.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", sid_6.data, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_5.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d", sid_5.data, sid_4.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast", sid_4.data, sid_3.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2", sid_3.data, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_2.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_3.data, T.lookup_param("p9", dtype="handle"), T.lookup_param("p10", dtype="handle"), sid_20.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320_", sid_20.data, T.lookup_param("p11", dtype="handle"), T.lookup_param("p12", dtype="handle"), sid_19.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", sid_3.data, T.lookup_param("p13", dtype="handle"), T.lookup_param("p14", dtype="handle"), sid_26.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__1", sid_26.data, T.lookup_param("p15", dtype="handle"), T.lookup_param("p16", dtype="handle"), sid_25.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast_1", sid_4.data, sid_32.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2", sid_32.data, T.lookup_param("p17", dtype="handle"), T.lookup_param("p18", dtype="handle"), sid_31.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_concatenate", sid_2.data, sid_19.data, sid_25.data, sid_31.data, output, dtype="int32")) __tvm_meta__ = None # fmt: on @@ -1107,231 +1107,231 @@ class MultipleCallsToSamePrimFuncModule: def tvmgen_default_fused_layout_transform_1(placeholder: T.handle, T_layout_trans: T.handle) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_layout_transform_1", "tir.noalias": True}) - placeholder_1 = T.match_buffer(placeholder, [1, 3, 24, 12], dtype="float32") - T_layout_trans_1 = T.match_buffer(T_layout_trans, [1, 1, 24, 12, 3], dtype="float32") + placeholder_1 = T.match_buffer(placeholder, [864], dtype="float32") + T_layout_trans_1 = T.match_buffer(T_layout_trans, [41], dtype="float32") # body for ax0_ax1_fused_ax2_fused, ax3, ax4_inner in T.grid(24, 12, 3): - T.store(T_layout_trans_1.data, ax0_ax1_fused_ax2_fused * 36 + ax3 * 3 + ax4_inner, T.load("float32", placeholder_1.data, ax4_inner * 288 + ax0_ax1_fused_ax2_fused * 12 + ax3), True) + T_layout_trans_1[ax0_ax1_fused_ax2_fused * 36 + ax3 * 3 + ax4_inner] = placeholder_1[ax4_inner * 288 + ax0_ax1_fused_ax2_fused * 12 + ax3]
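# NOTE: a companion sketch for the other half of the rewrite, as seen in the
# run_model hunk above: T.allocate extents are likewise flattened to 1-D, and
# extern calls receive the allocation's backing pointer via `.data` rather
# than the allocate node itself. `caller`, `scratch`, and "my_extern_op" are
# hypothetical names, not from this patch.
#
# from tvm.script import tir as T
#
# @T.prim_func
# def caller(x: T.handle) -> None:
#     X = T.match_buffer(x, [1092], dtype="float32")
#     scratch = T.allocate([1092], "float32", "global")  # was e.g. [1, 1, 26, 14, 3]
#     T.evaluate(T.call_extern("my_extern_op", X.data, scratch.data, dtype="int32"))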
@T.prim_func def tvmgen_default_fused_nn_contrib_conv2d_NCHWc(placeholder_2: T.handle, placeholder_3: T.handle, conv2d_NCHWc: T.handle) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_nn_contrib_conv2d_NCHWc", "tir.noalias": True}) - placeholder_4 = T.match_buffer(placeholder_2, [1, 1, 24, 12, 3], dtype="float32") - placeholder_5 = T.match_buffer(placeholder_3, [1, 1, 3, 3, 3, 3], dtype="float32") - conv2d_NCHWc_1 = T.match_buffer(conv2d_NCHWc, [1, 1, 24, 12, 3], dtype="float32") + placeholder_4 = T.match_buffer(placeholder_2, [864], dtype="float32") + placeholder_5 = T.match_buffer(placeholder_3, [81], dtype="float32") + conv2d_NCHWc_1 = T.match_buffer(conv2d_NCHWc, [41], dtype="float32") # body - data_pad = T.allocate([1, 1, 26, 14, 3], "float32", "global") + data_pad = T.allocate([1092], "float32", "global") for i0_i1_fused_i2_fused, i3, i4 in T.grid(26, 14, 3): - T.store(data_pad, i0_i1_fused_i2_fused * 42 + i3 * 3 + i4, T.if_then_else(1 <= i0_i1_fused_i2_fused and i0_i1_fused_i2_fused < 25 and 1 <= i3 and i3 < 13, T.load("float32", placeholder_4.data, i0_i1_fused_i2_fused * 36 + i3 * 3 + i4 - 39), T.float32(0), dtype="float32"), True) + data_pad[i0_i1_fused_i2_fused * 42 + i3 * 3 + i4] = T.if_then_else(1 <= i0_i1_fused_i2_fused and i0_i1_fused_i2_fused < 25 and 1 <= i3 and i3 < 13, placeholder_4[i0_i1_fused_i2_fused * 36 + i3 * 3 + i4 - 39], T.float32(0), dtype="float32") for n_oc_chunk_fused_oh_fused in T.serial(0, 24): - conv2d_NCHWc_global = T.allocate([1, 1, 1, 12, 3], "float32", "global") + conv2d_NCHWc_global = T.allocate([36], "float32", "global") for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init] = T.float32(0) for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init + 3, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 3] = T.float32(0) for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init + 6, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 6] = T.float32(0) for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init + 9, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 9] = T.float32(0) for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init + 12, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 12] = T.float32(0) for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init + 15, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 15] = T.float32(0) for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init + 18, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 18] = T.float32(0) for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init + 21, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 21] = T.float32(0) for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init + 24, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 24] = T.float32(0) for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init + 27, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 27] = T.float32(0) for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init + 30, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 30] = T.float32(0) for oc_block_c_init in T.serial(0, 3): -
T.store(conv2d_NCHWc_global, oc_block_c_init + 33, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 33] = T.float32(0) for kh, kw, ic_inner in T.grid(3, 3, 3): for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c, T.load("float32", conv2d_NCHWc_global, oc_block_c) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + conv2d_NCHWc_global[oc_block_c] = conv2d_NCHWc_global[oc_block_c] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 3, T.load("float32", conv2d_NCHWc_global, oc_block_c + 3) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 3) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + conv2d_NCHWc_global[oc_block_c + 3] = conv2d_NCHWc_global[oc_block_c + 3] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 3] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 6, T.load("float32", conv2d_NCHWc_global, oc_block_c + 6) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 6) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + conv2d_NCHWc_global[oc_block_c + 6] = conv2d_NCHWc_global[oc_block_c + 6] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 6] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 9, T.load("float32", conv2d_NCHWc_global, oc_block_c + 9) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 9) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + conv2d_NCHWc_global[oc_block_c + 9] = conv2d_NCHWc_global[oc_block_c + 9] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 9] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 12, T.load("float32", conv2d_NCHWc_global, oc_block_c + 12) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 12) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + conv2d_NCHWc_global[oc_block_c + 12] = conv2d_NCHWc_global[oc_block_c + 12] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 12] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 15, T.load("float32", conv2d_NCHWc_global, oc_block_c + 15) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 15) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + conv2d_NCHWc_global[oc_block_c + 15] = conv2d_NCHWc_global[oc_block_c + 15] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 15] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 18, T.load("float32", conv2d_NCHWc_global, oc_block_c + 18) + 
T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 18) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + conv2d_NCHWc_global[oc_block_c + 18] = conv2d_NCHWc_global[oc_block_c + 18] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 18] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 21, T.load("float32", conv2d_NCHWc_global, oc_block_c + 21) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 21) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + conv2d_NCHWc_global[oc_block_c + 21] = conv2d_NCHWc_global[oc_block_c + 21] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 21] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 24, T.load("float32", conv2d_NCHWc_global, oc_block_c + 24) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 24) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + conv2d_NCHWc_global[oc_block_c + 24] = conv2d_NCHWc_global[oc_block_c + 24] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 24] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 27, T.load("float32", conv2d_NCHWc_global, oc_block_c + 27) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 27) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + conv2d_NCHWc_global[oc_block_c + 27] = conv2d_NCHWc_global[oc_block_c + 27] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 27] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 30, T.load("float32", conv2d_NCHWc_global, oc_block_c + 30) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 30) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + conv2d_NCHWc_global[oc_block_c + 30] = conv2d_NCHWc_global[oc_block_c + 30] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 30] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 33, T.load("float32", conv2d_NCHWc_global, oc_block_c + 33) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 33) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + conv2d_NCHWc_global[oc_block_c + 33] = conv2d_NCHWc_global[oc_block_c + 33] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 33] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for ow_inner, oc_block in T.grid(12, 3): - T.store(conv2d_NCHWc_1.data, n_oc_chunk_fused_oh_fused * 36 + ow_inner * 3 + oc_block, T.load("float32", conv2d_NCHWc_global, ow_inner * 3 + oc_block), True) + conv2d_NCHWc_1[n_oc_chunk_fused_oh_fused * 36 + ow_inner * 3 + oc_block] = conv2d_NCHWc_global[ow_inner * 3 + oc_block] @T.prim_func def tvmgen_default_fused_nn_softmax_add_add_multiply_add(placeholder_6: T.handle, 
placeholder_7: T.handle, placeholder_8: T.handle, placeholder_9: T.handle, placeholder_10: T.handle, T_add: T.handle) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_nn_softmax_add_add_multiply_add", "tir.noalias": True}) - placeholder_11 = T.match_buffer(placeholder_6, [1, 3, 24, 12], dtype="float32") - placeholder_12 = T.match_buffer(placeholder_7, [1, 3, 24, 12], dtype="float32") - placeholder_13 = T.match_buffer(placeholder_8, [3, 1, 1], dtype="float32") - placeholder_14 = T.match_buffer(placeholder_9, [3, 1, 1], dtype="float32") - placeholder_15 = T.match_buffer(placeholder_10, [3, 1, 1], dtype="float32") - T_add_1 = T.match_buffer(T_add, [1, 3, 24, 12], dtype="float32") + placeholder_11 = T.match_buffer(placeholder_6, [864], dtype="float32") + placeholder_12 = T.match_buffer(placeholder_7, [864], dtype="float32") + placeholder_13 = T.match_buffer(placeholder_8, [3], dtype="float32") + placeholder_14 = T.match_buffer(placeholder_9, [3], dtype="float32") + placeholder_15 = T.match_buffer(placeholder_10, [3], dtype="float32") + T_add_1 = T.match_buffer(T_add, [864], dtype="float32") # body for ax0_ax1_fused_ax2_fused in T.serial(0, 72): - T_softmax_norm = T.allocate([1, 1, 1, 12], "float32", "global") - with T.allocate([1, 1, 1], "float32", "global") as T_softmax_maxelem: - T.store(T_softmax_maxelem, 0, T.float32(-3.4028234663852886e+38), True) + T_softmax_norm = T.allocate([12], "float32", "global") + with T.allocate([1], "float32", "global") as T_softmax_maxelem: + T_softmax_maxelem[0] = T.float32(-3.4028234663852886e+38) for k in T.serial(0, 12): - T.store(T_softmax_maxelem, 0, T.max(T.load("float32", T_softmax_maxelem, 0), T.load("float32", placeholder_11.data, ax0_ax1_fused_ax2_fused * 12 + k)), True) - T_softmax_exp = T.allocate([1, 1, 1, 12], "float32", "global") + T_softmax_maxelem[0] = T.max(T_softmax_maxelem[0], placeholder_11[ax0_ax1_fused_ax2_fused * 12 + k]) + T_softmax_exp = T.allocate([12], "float32", "global") for i3 in T.serial(0, 12): - T.store(T_softmax_exp, i3, T.exp(T.load("float32", placeholder_11.data, ax0_ax1_fused_ax2_fused * 12 + i3) - T.load("float32", T_softmax_maxelem, 0), dtype="float32"), True) - T_softmax_expsum = T.allocate([1, 1, 1], "float32", "global") - T.store(T_softmax_expsum, 0, T.float32(0), True) + T_softmax_exp[i3] = T.exp(placeholder_11[ax0_ax1_fused_ax2_fused * 12 + i3] - T_softmax_maxelem[0], dtype="float32") + T_softmax_expsum = T.allocate([1], "float32", "global") + T_softmax_expsum[0] = T.float32(0) for k in T.serial(0, 12): - T.store(T_softmax_expsum, 0, T.load("float32", T_softmax_expsum, 0) + T.load("float32", T_softmax_exp, k), True) + T_softmax_expsum[0] = T_softmax_expsum[0] + T_softmax_exp[k] for i3 in T.serial(0, 12): - T.store(T_softmax_norm, i3, T.load("float32", T_softmax_exp, i3) / T.load("float32", T_softmax_expsum, 0), True) + T_softmax_norm[i3] = T_softmax_exp[i3] / T_softmax_expsum[0] for ax3 in T.serial(0, 12): - T.store(T_add_1.data, ax0_ax1_fused_ax2_fused * 12 + ax3, (T.load("float32", placeholder_12.data, ax0_ax1_fused_ax2_fused * 12 + ax3) + T.load("float32", T_softmax_norm, ax3) + T.load("float32", placeholder_13.data, T.floordiv(ax0_ax1_fused_ax2_fused, 24))) * T.load("float32", placeholder_14.data, T.floordiv(ax0_ax1_fused_ax2_fused, 24)) + T.load("float32", placeholder_15.data, T.floordiv(ax0_ax1_fused_ax2_fused, 24)), True) + T_add_1[ax0_ax1_fused_ax2_fused * 12 + ax3] = (placeholder_12[ax0_ax1_fused_ax2_fused * 12 + ax3] + T_softmax_norm[ax3] + 
placeholder_13[T.floordiv(ax0_ax1_fused_ax2_fused, 24)]) * placeholder_14[T.floordiv(ax0_ax1_fused_ax2_fused, 24)] + placeholder_15[T.floordiv(ax0_ax1_fused_ax2_fused, 24)] @T.prim_func def tvmgen_default_fused_nn_contrib_dense_pack_nn_relu(placeholder_16: T.handle, placeholder_17: T.handle, T_relu: T.handle) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", "tir.noalias": True}) - placeholder_18 = T.match_buffer(placeholder_16, [72, 12], dtype="float32") - placeholder_19 = T.match_buffer(placeholder_17, [2, 12, 6], dtype="float32") - T_relu_1 = T.match_buffer(T_relu, [72, 12], dtype="float32") + placeholder_18 = T.match_buffer(placeholder_16, [864], dtype="float32") + placeholder_19 = T.match_buffer(placeholder_17, [144], dtype="float32") + T_relu_1 = T.match_buffer(T_relu, [864], dtype="float32") # body for ax1_outer_ax0_outer_fused in T.serial(0, 18): - compute = T.allocate([8, 6], "float32", "global") - with T.allocate([8, 6], "float32", "global") as compute_global: + compute = T.allocate([48], "float32", "global") + with T.allocate([48], "float32", "global") as compute_global: for x_c_init in T.serial(0, 6): - T.store(compute_global, x_c_init, T.float32(0), True) + compute_global[x_c_init] = T.float32(0) for x_c_init in T.serial(0, 6): - T.store(compute_global, x_c_init + 6, T.float32(0), True) + compute_global[x_c_init + 6] = T.float32(0) for x_c_init in T.serial(0, 6): - T.store(compute_global, x_c_init + 12, T.float32(0), True) + compute_global[x_c_init + 12] = T.float32(0) for x_c_init in T.serial(0, 6): - T.store(compute_global, x_c_init + 18, T.float32(0), True) + compute_global[x_c_init + 18] = T.float32(0) for x_c_init in T.serial(0, 6): - T.store(compute_global, x_c_init + 24, T.float32(0), True) + compute_global[x_c_init + 24] = T.float32(0) for x_c_init in T.serial(0, 6): - T.store(compute_global, x_c_init + 30, T.float32(0), True) + compute_global[x_c_init + 30] = T.float32(0) for x_c_init in T.serial(0, 6): - T.store(compute_global, x_c_init + 36, T.float32(0), True) + compute_global[x_c_init + 36] = T.float32(0) for x_c_init in T.serial(0, 6): - T.store(compute_global, x_c_init + 42, T.float32(0), True) + compute_global[x_c_init + 42] = T.float32(0) for k_outer in T.serial(0, 12): for x_c in T.serial(0, 6): - T.store(compute_global, x_c, T.load("float32", compute_global, x_c) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + compute_global[x_c] = compute_global[x_c] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] for x_c in T.serial(0, 6): - T.store(compute_global, x_c + 6, T.load("float32", compute_global, x_c + 6) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 12) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + compute_global[x_c + 6] = compute_global[x_c + 6] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 12] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] for x_c in T.serial(0, 6): - T.store(compute_global, x_c + 12, T.load("float32", compute_global, x_c + 12) + T.load("float32", placeholder_18.data, 
T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 24) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + compute_global[x_c + 12] = compute_global[x_c + 12] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 24] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] for x_c in T.serial(0, 6): - T.store(compute_global, x_c + 18, T.load("float32", compute_global, x_c + 18) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 36) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + compute_global[x_c + 18] = compute_global[x_c + 18] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 36] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] for x_c in T.serial(0, 6): - T.store(compute_global, x_c + 24, T.load("float32", compute_global, x_c + 24) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 48) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + compute_global[x_c + 24] = compute_global[x_c + 24] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 48] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] for x_c in T.serial(0, 6): - T.store(compute_global, x_c + 30, T.load("float32", compute_global, x_c + 30) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 60) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + compute_global[x_c + 30] = compute_global[x_c + 30] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 60] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] for x_c in T.serial(0, 6): - T.store(compute_global, x_c + 36, T.load("float32", compute_global, x_c + 36) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 72) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + compute_global[x_c + 36] = compute_global[x_c + 36] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 72] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] for x_c in T.serial(0, 6): - T.store(compute_global, x_c + 42, T.load("float32", compute_global, x_c + 42) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 84) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + compute_global[x_c + 42] = compute_global[x_c + 42] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 84] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner, T.load("float32", compute_global, x_inner_inner), True) + compute[x_inner_inner] = compute_global[x_inner_inner] for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner + 6, T.load("float32", compute_global, x_inner_inner + 6), True) + compute[x_inner_inner + 6] = compute_global[x_inner_inner + 6] for x_inner_inner in T.serial(0, 6): - T.store(compute, 
x_inner_inner + 12, T.load("float32", compute_global, x_inner_inner + 12), True) + compute[x_inner_inner + 12] = compute_global[x_inner_inner + 12] for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner + 18, T.load("float32", compute_global, x_inner_inner + 18), True) + compute[x_inner_inner + 18] = compute_global[x_inner_inner + 18] for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner + 24, T.load("float32", compute_global, x_inner_inner + 24), True) + compute[x_inner_inner + 24] = compute_global[x_inner_inner + 24] for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner + 30, T.load("float32", compute_global, x_inner_inner + 30), True) + compute[x_inner_inner + 30] = compute_global[x_inner_inner + 30] for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner + 36, T.load("float32", compute_global, x_inner_inner + 36), True) + compute[x_inner_inner + 36] = compute_global[x_inner_inner + 36] for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner + 42, T.load("float32", compute_global, x_inner_inner + 42), True) + compute[x_inner_inner + 42] = compute_global[x_inner_inner + 42] for ax0_inner_inner, ax1_inner_inner in T.grid(8, 6): - T.store(T_relu_1.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + ax0_inner_inner * 12 + T.floordiv(ax1_outer_ax0_outer_fused, 9) * 6 + ax1_inner_inner, T.max(T.load("float32", compute, ax0_inner_inner * 6 + ax1_inner_inner), T.float32(0)), True) + T_relu_1[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + ax0_inner_inner * 12 + T.floordiv(ax1_outer_ax0_outer_fused, 9) * 6 + ax1_inner_inner] = T.max(compute[ax0_inner_inner * 6 + ax1_inner_inner], T.float32(0)) @T.prim_func def tvmgen_default_fused_reshape_1(placeholder_20: T.handle, T_reshape: T.handle) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_reshape_1", "tir.noalias": True}) - placeholder_21 = T.match_buffer(placeholder_20, [1, 3, 24, 12], dtype="float32") - T_reshape_1 = T.match_buffer(T_reshape, [72, 12], dtype="float32") + placeholder_21 = T.match_buffer(placeholder_20, [864], dtype="float32") + T_reshape_1 = T.match_buffer(T_reshape, [864], dtype="float32") # body for ax0, ax1_inner in T.grid(72, 12): - T.store(T_reshape_1.data, ax0 * 12 + ax1_inner, T.load("float32", placeholder_21.data, ax0 * 12 + ax1_inner), True) + T_reshape_1[ax0 * 12 + ax1_inner] = placeholder_21[ax0 * 12 + ax1_inner] @T.prim_func def tvmgen_default_fused_layout_transform(placeholder_22: T.handle, T_layout_trans_2: T.handle) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_layout_transform", "tir.noalias": True}) - placeholder_23 = T.match_buffer(placeholder_22, [1, 1, 24, 12, 3], dtype="float32") - T_layout_trans_3 = T.match_buffer(T_layout_trans_2, [1, 3, 24, 12], dtype="float32") + placeholder_23 = T.match_buffer(placeholder_22, [864], dtype="float32") + T_layout_trans_3 = T.match_buffer(T_layout_trans_2, [864], dtype="float32") # body for ax0_ax1_fused, ax2, ax3_inner in T.grid(3, 24, 12): - T.store(T_layout_trans_3.data, ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner, T.load("float32", placeholder_23.data, ax2 * 36 + ax3_inner * 3 + ax0_ax1_fused), True) + T_layout_trans_3[ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner] = placeholder_23[ax2 * 36 + ax3_inner * 3 + ax0_ax1_fused] @T.prim_func def tvmgen_default_fused_reshape(placeholder_24: T.handle, T_reshape_2: T.handle) -> None: # function attr dict 
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_reshape", "tir.noalias": True}) - placeholder_25 = T.match_buffer(placeholder_24, [72, 12], dtype="float32") - T_reshape_3 = T.match_buffer(T_reshape_2, [1, 3, 24, 12], dtype="float32") + placeholder_25 = T.match_buffer(placeholder_24, [864], dtype="float32") + T_reshape_3 = T.match_buffer(T_reshape_2, [864], dtype="float32") # body for ax0_ax1_fused, ax2, ax3_inner in T.grid(3, 24, 12): - T.store(T_reshape_3.data, ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner, T.load("float32", placeholder_25.data, ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner), True) + T_reshape_3[ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner] = placeholder_25[ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner] @T.prim_func def tvmgen_default_fused_nn_softmax_add(placeholder_26: T.handle, placeholder_27: T.handle, T_add_2: T.handle) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_nn_softmax_add", "tir.noalias": True}) - placeholder_28 = T.match_buffer(placeholder_26, [1, 3, 24, 12], dtype="float32") - placeholder_29 = T.match_buffer(placeholder_27, [1, 3, 24, 12], dtype="float32") - T_add_3 = T.match_buffer(T_add_2, [1, 3, 24, 12], dtype="float32") + placeholder_28 = T.match_buffer(placeholder_26, [864], dtype="float32") + placeholder_29 = T.match_buffer(placeholder_27, [864], dtype="float32") + T_add_3 = T.match_buffer(T_add_2, [864], dtype="float32") # body for ax0_ax1_fused_ax2_fused in T.serial(0, 72): - T_softmax_norm = T.allocate([1, 1, 1, 12], "float32", "global") - with T.allocate([1, 1, 1], "float32", "global") as T_softmax_maxelem: - T.store(T_softmax_maxelem, 0, T.float32(-3.4028234663852886e+38), True) + T_softmax_norm = T.allocate([12], "float32", "global") + with T.allocate([1], "float32", "global") as T_softmax_maxelem: + T_softmax_maxelem[0] = T.float32(-3.4028234663852886e+38) for k in T.serial(0, 12): - T.store(T_softmax_maxelem, 0, T.max(T.load("float32", T_softmax_maxelem, 0), T.load("float32", placeholder_28.data, ax0_ax1_fused_ax2_fused * 12 + k)), True) - T_softmax_exp = T.allocate([1, 1, 1, 12], "float32", "global") + T_softmax_maxelem[0] = T.max(T_softmax_maxelem[0], placeholder_28[ax0_ax1_fused_ax2_fused * 12 + k]) + T_softmax_exp = T.allocate([12], "float32", "global") for i3 in T.serial(0, 12): - T.store(T_softmax_exp, i3, T.exp(T.load("float32", placeholder_28.data, ax0_ax1_fused_ax2_fused * 12 + i3) - T.load("float32", T_softmax_maxelem, 0), dtype="float32"), True) - T_softmax_expsum = T.allocate([1, 1, 1], "float32", "global") - T.store(T_softmax_expsum, 0, T.float32(0), True) + T_softmax_exp[i3] = T.exp(placeholder_28[ax0_ax1_fused_ax2_fused * 12 + i3] - T_softmax_maxelem[0], dtype="float32") + T_softmax_expsum = T.allocate([1], "float32", "global") + T_softmax_expsum[0] = T.float32(0) for k in T.serial(0, 12): - T.store(T_softmax_expsum, 0, T.load("float32", T_softmax_expsum, 0) + T.load("float32", T_softmax_exp, k), True) + T_softmax_expsum[0] = T_softmax_expsum[0] + T_softmax_exp[k] for i3 in T.serial(0, 12): - T.store(T_softmax_norm, i3, T.load("float32", T_softmax_exp, i3) / T.load("float32", T_softmax_expsum, 0), True) + T_softmax_norm[i3] = T_softmax_exp[i3] / T_softmax_expsum[0] for ax3 in T.serial(0, 12): - T.store(T_add_3.data, ax0_ax1_fused_ax2_fused * 12 + ax3, T.load("float32", placeholder_29.data, ax0_ax1_fused_ax2_fused * 12 + ax3) + T.load("float32", T_softmax_norm, ax3), True) + T_add_3[ax0_ax1_fused_ax2_fused * 12 + ax3] = 
placeholder_29[ax0_ax1_fused_ax2_fused * 12 + ax3] + T_softmax_norm[ax3] @T.prim_func def run_model(data: T.handle, output: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) - data_buffer = T.match_buffer(data, [1, 3, 24, 12], dtype="float32", align=16) - output_buffer = T.match_buffer(output, [1, 3, 24, 12], dtype="float32", align=16) + data_buffer = T.match_buffer(data, [864], dtype="float32", align=16) + output_buffer = T.match_buffer(output, [864], dtype="float32", align=16) # body sid_11 = T.allocate([3456], "int8", "global.workspace") sid_5 = T.allocate([3456], "int8", "global.workspace") @@ -1346,20 +1346,20 @@ def run_model(data: T.handle, output: T.handle) -> None: sid_18 = T.allocate([3456], "int8", "global.workspace") sid_19 = T.allocate([3456], "int8", "global.workspace") sid_20 = T.allocate([3456], "int8", "global.workspace") - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform_1", data_buffer.data, sid_8, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_conv2d_NCHWc", sid_8, T.cast(T.lookup_param("p0", dtype="handle"), "handle"), sid_7, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform", sid_7, sid_6, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape_1", data_buffer.data, sid_12, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", sid_12, T.cast(T.lookup_param("p1", dtype="handle"), "handle"), sid_11, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape", sid_11, sid_10, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_softmax_add_add_multiply_add", sid_6, sid_10, T.cast(T.lookup_param("p2", dtype="handle"), "handle"), T.cast(T.lookup_param("p3", dtype="handle"), "handle"), T.cast(T.lookup_param("p4", dtype="handle"), "handle"), sid_5, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform_1", sid_5, sid_4, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_conv2d_NCHWc", sid_4, T.cast(T.lookup_param("p5", dtype="handle"), "handle"), sid_3, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform", sid_3, sid_2, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape_1", sid_5, sid_20, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", sid_20, T.cast(T.lookup_param("p6", dtype="handle"), "handle"), sid_19, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape", sid_19, sid_18, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_softmax_add", sid_2, sid_18, output_buffer.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform_1", data_buffer.data, sid_8.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_conv2d_NCHWc", sid_8.data, T.cast(T.lookup_param("p0", dtype="handle"), "handle"), sid_7.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform", sid_7.data, sid_6.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape_1", data_buffer.data, sid_12.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", sid_12.data, T.cast(T.lookup_param("p1", dtype="handle"), "handle"), sid_11.data, dtype="int32")) + 
T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape", sid_11.data, sid_10.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_softmax_add_add_multiply_add", sid_6.data, sid_10.data, T.cast(T.lookup_param("p2", dtype="handle"), "handle"), T.cast(T.lookup_param("p3", dtype="handle"), "handle"), T.cast(T.lookup_param("p4", dtype="handle"), "handle"), sid_5.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform_1", sid_5.data, sid_4.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_conv2d_NCHWc", sid_4.data, T.cast(T.lookup_param("p5", dtype="handle"), "handle"), sid_3.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform", sid_3.data, sid_2.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape_1", sid_5.data, sid_20.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", sid_20.data, T.cast(T.lookup_param("p6", dtype="handle"), "handle"), sid_19.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape", sid_19.data, sid_18.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_softmax_add", sid_2.data, sid_18.data, output_buffer.data, dtype="int32")) # fmt: on diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index 07e31a989874..4ed02615cd44 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -73,53 +73,53 @@ class LinearStructure: def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) - placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) - T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): - T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True) + T_subtract_1[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)] = (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5[0]) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", 
"tir.noalias": True}) - placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): for i2_7, i3_7 in T.grid(229, 3): - T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True) + PaddedInput_7[(((i0_i1_fused_7*687) + (i2_7*3)) + i3_7)] = T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): Conv2dOutput_7 = T.allocate([64], "int32", "global") for ff_3 in T.serial(0, 64): - T.store(Conv2dOutput_7, ff_3, 0, True) + Conv2dOutput_7[ff_3] = 0 for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): - T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True) + Conv2dOutput_7[ff_3] = (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))) for ax3_inner_7 in T.serial(0, 64): - T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + T_cast_21[((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8") @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) - placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, 
offset_factor=1) + placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body tensor_2 = T.allocate([200704], "uint8", "global") for ax0_ax1_fused_4 in T.serial(0, 56): for ax2_4 in T.serial(0, 56): for ax3_init in T.serial(0, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init)] = T.uint8(0) for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)] = T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")) for ax0_ax1_fused_5 in T.serial(0, 56): for ax2_5, ax3_3 in T.grid(56, 64): - T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) + T_cast_7[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)] = T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16") @T.prim_func def __tvm_main__(input: T.handle, output: T.handle) -> None: @@ -130,9 +130,9 @@ def __tvm_main__(input: T.handle, output: T.handle) -> None: T.attr("default", "device_type", 1) sid_9 = T.allocate([301056], "int8", "global") sid_8 = T.allocate([802816], "int8", "global") - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9.data, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8.data, output, dtype="int32")) # fmt: on @@ -140,65 +140,68 @@ def __tvm_main__(input: T.handle, output: T.handle) -> None: @tvm.script.ir_module class LinearStructurePlanned: @T.prim_func - def __tvm_main__(input: T.handle, fast_memory_0_var: T.handle, slow_memory_1_var: T.handle, output: T.handle) -> None: + def __tvm_main__(input: T.handle, fast_memory_0_var: T.Ptr[T.uint8], slow_memory_1_var: T.Ptr[T.uint8], output: T.handle) -> None: fast_memory_0_buffer_var = 
T.match_buffer(fast_memory_0_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) - sid_9_let: T.handle = T.address_of(T.load("uint8", slow_memory_1_buffer_var.data, 1117472), dtype="handle") - sid_8_let: T.handle = T.address_of(T.load("uint8", slow_memory_1_buffer_var.data, 0), dtype="handle") + sid_9_let: T.Ptr[T.int8] = T.address_of(slow_memory_1_buffer_var[1117472], dtype="handle") + sid_8_let: T.Ptr[T.int8] = T.address_of(slow_memory_1_buffer_var[0], dtype="handle") T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8_let, output, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) @T.prim_func - def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.handle, slow_memory_7_var: T.handle) -> None: - placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8") - T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16") + def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.Ptr[T.uint8], slow_memory_7_var: T.Ptr[T.uint8]) -> None: + placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8") + T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16") fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - tensor_2_let: T.handle = T.address_of(T.load("uint8", fast_memory_6_buffer_var.data, 0), dtype="handle") - for ax0_ax1_fused_4, ax2_4 in T.grid(56, 56): - for ax3_init in T.serial(0, 64): - T.store(tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_init, T.uint8(0), True) - for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2, T.max(T.load("uint8", tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2), T.if_then_else(ax0_ax1_fused_4 * 2 + rv0_rv1_fused_1 // 3 < 112 and ax2_4 * 2 + rv0_rv1_fused_1 % 3 < 112, T.load("uint8", placeholder_29.data, ax0_ax1_fused_4 * 14336 + rv0_rv1_fused_1 // 3 * 7168 + ax2_4 * 128 + rv0_rv1_fused_1 % 3 * 64 + ax3_2), T.uint8(0), dtype="uint8")), True) - for ax0_ax1_fused_5, ax2_5, ax3_3 in T.grid(56, 56, 64): - T.store(T_cast_7.data, ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3, T.cast(T.load("uint8", tensor_2_let, ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3), "int16"), True) + tensor_2_let = T.buffer_decl([200704], dtype="uint8") + with T.let(tensor_2_let.data, T.address_of(fast_memory_6_buffer_var[0], dtype="handle")): + for ax0_ax1_fused_4, ax2_4 in T.grid(56, 56): + for ax3_init in T.serial(0, 64): + tensor_2_let[ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_init] = T.uint8(0) + for rv0_rv1_fused_1, ax3_2 in T.grid(9, 
64): + tensor_2_let[ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2] = T.max(tensor_2_let[ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2], T.if_then_else(ax0_ax1_fused_4 * 2 + rv0_rv1_fused_1 // 3 < 112 and ax2_4 * 2 + rv0_rv1_fused_1 % 3 < 112, placeholder_29[ax0_ax1_fused_4 * 14336 + rv0_rv1_fused_1 // 3 * 7168 + ax2_4 * 128 + rv0_rv1_fused_1 % 3 * 64 + ax3_2], T.uint8(0), dtype="uint8")) + for ax0_ax1_fused_5, ax2_5, ax3_3 in T.grid(56, 56, 64): + T_cast_7[ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3] = T.cast(tensor_2_let[ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3], "int16") @T.prim_func - def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.handle, slow_memory_3_var: T.handle) -> None: - placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8") - placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16") - T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16") + def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.Ptr[T.uint8], slow_memory_3_var: T.Ptr[T.uint8]) -> None: + placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8") + placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16") + T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16") fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) # body for ax0_ax1_fused_1, ax2_1, ax3_inner_1 in T.grid(224, 224, 3): - T.store(T_subtract_1.data, ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1, T.cast(T.load("uint8", placeholder_4.data, ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1), "int16") - T.load("int16", placeholder_5.data, 0), True) + T_subtract_1[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1] = T.cast(placeholder_4[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1], "int16") - placeholder_5[0] @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.handle, slow_memory_5_var: T.handle) -> None: - placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16") - placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16") - placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32") - T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8") + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.Ptr[T.uint8], slow_memory_5_var: T.Ptr[T.uint8]) -> None: + placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16") + placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16") + placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32") + T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8") fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_7_let: T.handle = T.address_of(T.load("uint8", 
slow_memory_5_buffer_var.data, 802816), dtype="handle") - for i0_i1_fused_7, i2_7, i3_7 in T.grid(229, 229, 3): - T.store(PaddedInput_7_let, i0_i1_fused_7 * 687 + i2_7 * 3 + i3_7, T.if_then_else(2 <= i0_i1_fused_7 and i0_i1_fused_7 < 226 and 2 <= i2_7 and i2_7 < 226, T.load("int16", placeholder_65.data, i0_i1_fused_7 * 672 + i2_7 * 3 + i3_7 - 1350), T.int16(0), dtype="int16"), True) - for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): - Conv2dOutput_7_let: T.handle = T.address_of(T.load("uint8", fast_memory_4_buffer_var.data, 0), dtype="handle") - for ff_3 in T.serial(0, 64): - T.store(Conv2dOutput_7_let, ff_3, 0, True) - for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): - T.store(Conv2dOutput_7_let, ff_3, T.load("int32", Conv2dOutput_7_let, ff_3) + T.cast(T.load("int16", PaddedInput_7_let, ax0_ax1_fused_ax2_fused_7 // 112 * 1374 + ry_2 * 687 + ax0_ax1_fused_ax2_fused_7 % 112 * 6 + rx_2 * 3 + rc_7), "int32") * T.cast(T.load("int16", placeholder_66.data, ry_2 * 1344 + rx_2 * 192 + rc_7 * 64 + ff_3), "int32"), True) - for ax3_inner_7 in T.serial(0, 64): - T.store(T_cast_21.data, ax0_ax1_fused_ax2_fused_7 * 64 + ax3_inner_7, T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_7_let, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + PaddedInput_7_let = T.buffer_decl([157323], "int16") + with T.let(PaddedInput_7_let.data, T.address_of(slow_memory_5_buffer_var[802816], dtype="handle")): + for i0_i1_fused_7, i2_7, i3_7 in T.grid(229, 229, 3): + PaddedInput_7_let[i0_i1_fused_7 * 687 + i2_7 * 3 + i3_7] = T.if_then_else(2 <= i0_i1_fused_7 and i0_i1_fused_7 < 226 and 2 <= i2_7 and i2_7 < 226, placeholder_65[i0_i1_fused_7 * 672 + i2_7 * 3 + i3_7 - 1350], T.int16(0), dtype="int16") + for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): + Conv2dOutput_7_let = T.buffer_decl([64], "int32") + with T.let(Conv2dOutput_7_let.data, T.address_of(fast_memory_4_buffer_var[0], dtype="handle")): + for ff_3 in T.serial(0, 64): + Conv2dOutput_7_let[ff_3] = 0 + for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): + Conv2dOutput_7_let[ff_3] = Conv2dOutput_7_let[ff_3] + T.cast(PaddedInput_7_let[ax0_ax1_fused_ax2_fused_7 // 112 * 1374 + ry_2 * 687 + ax0_ax1_fused_ax2_fused_7 % 112 * 6 + rx_2 * 3 + rc_7], "int32") * T.cast(placeholder_66[ry_2 * 1344 + rx_2 * 192 + rc_7 * 64 + ff_3], "int32") + for ax3_inner_7 in T.serial(0, 64): + T_cast_21[ax0_ax1_fused_ax2_fused_7 * 64 + ax3_inner_7] = T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_7_let[ax3_inner_7] + placeholder_67[ax3_inner_7], 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8") # fmt: on @@ -234,17 +237,10 @@ def test_mobilenet_subgraph(): )(tir_mod) tir_mod_with_offsets_ref = LinearStructurePlanned - tir_mod_with_offsets_ref = tvm.script.from_source( - tir_mod_with_offsets_ref.script(show_meta=False) - ) - # The TIR produced fails on roundtrip TVMScript testing. - # Therefore, indicates the TVMScript produced here and/or the parser - # is lacking functionality. Thus for these tests, uses a string - # version of the TVMScript for each function as a check instead. 
- for gv, func in tir_mod_with_offsets_ref.functions.items(): - assert str(tir_mod_with_offsets_ref[gv.name_hint].script()) == str( - tir_mod_with_offsets[gv.name_hint].script() - ) + + for gv, ref_func in tir_mod_with_offsets_ref.functions.items(): + actual_func = tir_mod_with_offsets[gv.name_hint] + tvm.ir.assert_structural_equal(actual_func, ref_func) # fmt: off @@ -254,78 +250,78 @@ class ResnetStructure: def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True}) - placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8") + placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") - T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16") + T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16") # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): - T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True) + T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True}) - placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16") - placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16") - placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32") - T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16") + placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16") + placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16") + placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32") + T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16") # body PaddedInput_1 = T.allocate([379456], "int16", "global") for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): - T.store(PaddedInput_1, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True) + PaddedInput_1[i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1] = T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, placeholder_13[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): Conv2dOutput_1 = T.allocate([64], "int32", "global") for ff_1 in T.serial(0, 64): - 
T.store(Conv2dOutput_1, ff_1, 0, True) + Conv2dOutput_1[ff_1] = 0 for ry, rx, rc_1 in T.grid(3, 3, 64): - T.store(Conv2dOutput_1, ff_1, T.load("int32", Conv2dOutput_1, ff_1) + T.cast(T.load("int16", PaddedInput_1, T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True) + Conv2dOutput_1[ff_1] = Conv2dOutput_1[ff_1] + T.cast(PaddedInput_1[T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1], "int32") * T.cast(placeholder_14[ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1], "int32") for ax3_inner_2 in T.serial(0, 64): - T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + T_cast_5[ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_1[ax3_inner_2] + placeholder_15[ax3_inner_2], 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", "tir.noalias": True}) - placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16") - placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16") - placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32") - T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32") + placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16") + placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16") + placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32") + T_add_1 = T.match_buffer(T_add, [407], dtype="int32") # body PaddedInput_2 = T.allocate([360000], "int16", "global") for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): - T.store(PaddedInput_2, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True) + PaddedInput_2[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] = placeholder_19[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): Conv2dOutput_2 = T.allocate([64], "int32", "global") for ax3_outer_1 in T.serial(0, 4): for ff_2 in T.serial(0, 64): - T.store(Conv2dOutput_2, ff_2, 0, True) + Conv2dOutput_2[ff_2] = 0 for rc_2 in T.serial(0, 64): - T.store(Conv2dOutput_2, ff_2, T.load("int32", Conv2dOutput_2, ff_2) + T.cast(T.load("int16", PaddedInput_2, ax0_ax1_fused_ax2_fused_2 * 64 + rc_2), "int32") * T.cast(T.load("int16", placeholder_20.data, rc_2 * 256 + ax3_outer_1 * 64 + ff_2), "int32"), True) + Conv2dOutput_2[ff_2] = Conv2dOutput_2[ff_2] + T.cast(PaddedInput_2[ax0_ax1_fused_ax2_fused_2 * 64 + rc_2], "int32") * T.cast(placeholder_20[rc_2 * 256 + ax3_outer_1 * 64 + ff_2], "int32") for ax3_inner_3 in T.serial(0, 64): - T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, 
T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True) + T_add_1[ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3] = T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_2[ax3_inner_3] + placeholder_21[ax3_outer_1 * 64 + ax3_inner_3], 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136 @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True}) - placeholder_29 = T.match_buffer(placeholder_22, [1, 75, 75, 64], dtype="int16") - placeholder_27 = T.match_buffer(placeholder_23, [1, 1, 64, 256], dtype="int16") - placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32") - placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32") - T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8") + placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16") + placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16") + placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32") + placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32") + T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8") # body PaddedInput_3 = T.allocate([360000], "int16", "global") for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): - T.store(PaddedInput_3, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, T.load("int16", placeholder_29.data, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3), True) + PaddedInput_3[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] = placeholder_29[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): Conv2dOutput_3 = T.allocate([64], "int32", "global") for ax3_outer_2 in T.serial(0, 4): for ff_3 in T.serial(0, 64): - T.store(Conv2dOutput_3, ff_3, 0, True) + Conv2dOutput_3[ff_3] = 0 for rc_3 in T.serial(0, 64): - T.store(Conv2dOutput_3, ff_3, T.load("int32", Conv2dOutput_3, ff_3) + T.cast(T.load("int16", PaddedInput_3, ax0_ax1_fused_ax2_fused_3 * 64 + rc_3), "int32") * T.cast(T.load("int16", placeholder_27.data, rc_3 * 256 + ax3_outer_2 * 64 + ff_3), "int32"), True) + Conv2dOutput_3[ff_3] = Conv2dOutput_3[ff_3] + T.cast(PaddedInput_3[ax0_ax1_fused_ax2_fused_3 * 64 + rc_3], "int32") * T.cast(placeholder_27[rc_3 * 256 + ax3_outer_2 * 64 + ff_3], "int32") for ax3_inner_4 in T.serial(0, 64): - T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_3, ax3_inner_4) + T.load("int32", placeholder_26.data, ax3_outer_2 * 64 + ax3_inner_4), 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + T.load("int32", placeholder_28.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4), 255), 0), "uint8"), True) + 
T_cast_7[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4] = T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_3[ax3_inner_4] + placeholder_26[ax3_outer_2 * 64 + ax3_inner_4], 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + placeholder_28[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4], 255), 0), "uint8") @T.prim_func def __tvm_main__(input: T.handle, output: T.handle) -> None: @@ -338,32 +334,32 @@ def __tvm_main__(input: T.handle, output: T.handle) -> None: sid_6 = T.allocate([5760000], "int8", "global") sid_7 = T.allocate([720000], "int8", "global") sid_8 = T.allocate([720000], "int8", "global") - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6, output, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2.data, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8.data, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7.data, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2.data, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6.data, output, dtype="int32")) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True}) - placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16") - placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16") - placeholder_9 = T.match_buffer(placeholder_6, [1, 
1, 1, 64], dtype="int32") - T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16") + placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16") + placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16") + placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32") + T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16") # body PaddedInput = T.allocate([360000], "int16", "global") for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): - T.store(PaddedInput, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True) + PaddedInput[i0_i1_fused * 4800 + i2 * 64 + i3] = placeholder_7[i0_i1_fused * 4800 + i2 * 64 + i3] for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): Conv2dOutput = T.allocate([64], "int32", "global") for ff in T.serial(0, 64): - T.store(Conv2dOutput, ff, 0, True) + Conv2dOutput[ff] = 0 for rc in T.serial(0, 64): - T.store(Conv2dOutput, ff, T.load("int32", Conv2dOutput, ff) + T.cast(T.load("int16", PaddedInput, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True) + Conv2dOutput[ff] = Conv2dOutput[ff] + T.cast(PaddedInput[ax0_ax1_fused_ax2_fused * 64 + rc], "int32") * T.cast(placeholder_8[rc * 64 + ff], "int32") for ax3_inner_1 in T.serial(0, 64): - T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) + T_cast_3[ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput[ax3_inner_1] + placeholder_9[ax3_inner_1], 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16") # fmt: on @@ -371,108 +367,116 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place @tvm.script.ir_module class ResnetStructurePlanned: @T.prim_func - def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.handle) -> None: - placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8") + def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.Ptr[T.uint8]) -> None: + placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") - T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16") + T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16") global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): - T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True) + T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], 
"int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.handle) -> None: - placeholder_29 = T.match_buffer(placeholder_22, [1, 75, 75, 64], dtype="int16") - placeholder_27 = T.match_buffer(placeholder_23, [1, 1, 64, 256], dtype="int16") - placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32") - placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32") - T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8") + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.Ptr[T.uint8]) -> None: + placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16") + placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16") + placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32") + placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32") + T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8") global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 6480000), dtype="handle") - for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): - T.store(PaddedInput_3_let, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, T.load("int16", placeholder_29.data, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3), True) - for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): - Conv2dOutput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 7200000), dtype="handle") - for ax3_outer_2 in T.serial(0, 4): - for ff_3 in T.serial(0, 64): - T.store(Conv2dOutput_3_let, ff_3, 0, True) - for rc_3 in T.serial(0, 64): - T.store(Conv2dOutput_3_let, ff_3, T.load("int32", Conv2dOutput_3_let, ff_3) + T.cast(T.load("int16", PaddedInput_3_let, ax0_ax1_fused_ax2_fused_3 * 64 + rc_3), "int32") * T.cast(T.load("int16", placeholder_27.data, rc_3 * 256 + ax3_outer_2 * 64 + ff_3), "int32"), True) - for ax3_inner_4 in T.serial(0, 64): - T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_3_let, ax3_inner_4) + T.load("int32", placeholder_26.data, ax3_outer_2 * 64 + ax3_inner_4), 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + T.load("int32", placeholder_28.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4), 255), 0), "uint8"), True) + PaddedInput_3_let = T.buffer_decl([360000], 'int16') + with T.let(PaddedInput_3_let.data, T.address_of(global_workspace_5_buffer_var[6480000], dtype="handle")): + for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): + PaddedInput_3_let[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] = placeholder_29[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] + for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): + Conv2dOutput_3_let = 
T.buffer_decl([64], 'int32') + with T.let(Conv2dOutput_3_let.data, T.address_of(global_workspace_5_buffer_var[7200000], dtype="handle")): + for ax3_outer_2 in T.serial(0, 4): + for ff_3 in T.serial(0, 64): + Conv2dOutput_3_let[ff_3] = 0 + for rc_3 in T.serial(0, 64): + Conv2dOutput_3_let[ff_3] = Conv2dOutput_3_let[ff_3] + T.cast(PaddedInput_3_let[ax0_ax1_fused_ax2_fused_3 * 64 + rc_3], "int32") * T.cast(placeholder_27[rc_3 * 256 + ax3_outer_2 * 64 + ff_3], "int32") + for ax3_inner_4 in T.serial(0, 64): + T_cast_7[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4] = T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_3_let[ax3_inner_4] + placeholder_26[ax3_outer_2 * 64 + ax3_inner_4], 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + placeholder_28[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4], 255), 0), "uint8") @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.handle) -> None: - placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16") - placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16") - placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32") - T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32") + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.Ptr[T.uint8]) -> None: + placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16") + placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16") + placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32") + T_add_1 = T.match_buffer(T_add, [407], dtype="int32") global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 7200000), dtype="handle") - for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): - T.store(PaddedInput_2_let, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True) - for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): - Conv2dOutput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 7920000), dtype="handle") - for ax3_outer_1 in T.serial(0, 4): - for ff_2 in T.serial(0, 64): - T.store(Conv2dOutput_2_let, ff_2, 0, True) - for rc_2 in T.serial(0, 64): - T.store(Conv2dOutput_2_let, ff_2, T.load("int32", Conv2dOutput_2_let, ff_2) + T.cast(T.load("int16", PaddedInput_2_let, ax0_ax1_fused_ax2_fused_2 * 64 + rc_2), "int32") * T.cast(T.load("int16", placeholder_20.data, rc_2 * 256 + ax3_outer_1 * 64 + ff_2), "int32"), True) - for ax3_inner_3 in T.serial(0, 64): - T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2_let, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 
132, 2094289803, 31, -2, dtype="int32") + 136, True) + PaddedInput_2_let = T.buffer_decl([360000], "int16") + with T.let(PaddedInput_2_let.data, T.address_of(global_workspace_4_buffer_var[7200000], dtype="handle")): + for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): + PaddedInput_2_let[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] = placeholder_19[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] + for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): + Conv2dOutput_2_let = T.buffer_decl([64], 'int32') + with T.let(Conv2dOutput_2_let.data, T.address_of(global_workspace_4_buffer_var[7920000], dtype="handle")): + for ax3_outer_1 in T.serial(0, 4): + for ff_2 in T.serial(0, 64): + Conv2dOutput_2_let[ff_2] = 0 + for rc_2 in T.serial(0, 64): + Conv2dOutput_2_let[ff_2] = Conv2dOutput_2_let[ff_2] + T.cast(PaddedInput_2_let[ax0_ax1_fused_ax2_fused_2 * 64 + rc_2], "int32") * T.cast(placeholder_20[rc_2 * 256 + ax3_outer_1 * 64 + ff_2], "int32") + for ax3_inner_3 in T.serial(0, 64): + T_add_1[ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3] = T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_2_let[ax3_inner_3] + placeholder_21[ax3_outer_1 * 64 + ax3_inner_3], 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136 @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.handle) -> None: - placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16") - placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16") - placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32") - T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16") + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.Ptr[T.uint8]) -> None: + placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16") + placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16") + placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32") + T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16") global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 7200000), dtype="handle") - for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): - T.store(PaddedInput_let, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True) - for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): - Conv2dOutput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 7920000), dtype="handle") - for ff in T.serial(0, 64): - T.store(Conv2dOutput_let, ff, 0, True) - for rc in T.serial(0, 64): - T.store(Conv2dOutput_let, ff, T.load("int32", Conv2dOutput_let, ff) + T.cast(T.load("int16", PaddedInput_let, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True) - for ax3_inner_1 in T.serial(0, 64): - T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_let, ax3_inner_1) + T.load("int32", placeholder_9.data, 
ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) + PaddedInput_let = T.buffer_decl([360000], "int16") + with T.let(PaddedInput_let.data, T.address_of(global_workspace_2_buffer_var[7200000], dtype="handle")): + for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): + PaddedInput_let[i0_i1_fused * 4800 + i2 * 64 + i3] = placeholder_7[i0_i1_fused * 4800 + i2 * 64 + i3] + for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): + Conv2dOutput_let = T.buffer_decl([64], "int32") + with T.let(Conv2dOutput_let.data, T.address_of(global_workspace_2_buffer_var[7920000], dtype="handle")): + for ff in T.serial(0, 64): + Conv2dOutput_let[ff] = 0 + for rc in T.serial(0, 64): + Conv2dOutput_let[ff] = Conv2dOutput_let[ff] + T.cast(PaddedInput_let[ax0_ax1_fused_ax2_fused * 64 + rc], "int32") * T.cast(placeholder_8[rc * 64 + ff], "int32") + for ax3_inner_1 in T.serial(0, 64): + T_cast_3[ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_let[ax3_inner_1] + placeholder_9[ax3_inner_1], 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16") @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.handle) -> None: - placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16") - placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16") - placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32") - T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16") + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.Ptr[T.uint8]) -> None: + placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16") + placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16") + placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32") + T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16") global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_1_let: T.handle = T.address_of(T.load("uint8", global_workspace_3_buffer_var.data, 0), dtype="handle") - for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): - T.store(PaddedInput_1_let, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True) - for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): - Conv2dOutput_1_let: T.handle = T.address_of(T.load("uint8", global_workspace_3_buffer_var.data, 7200000), dtype="handle") - for ff_1 in T.serial(0, 64): - T.store(Conv2dOutput_1_let, ff_1, 0, True) - for ry, rx, rc_1 in T.grid(3, 3, 64): - T.store(Conv2dOutput_1_let, ff_1, T.load("int32", Conv2dOutput_1_let, ff_1) + T.cast(T.load("int16", PaddedInput_1_let, ax0_ax1_fused_ax2_fused_1 // 75 * 4928 + ry * 4928 + rx * 64 + ax0_ax1_fused_ax2_fused_1 % 75 * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True) - for ax3_inner_2 in T.serial(0, 64): - T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, 
T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1_let, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + PaddedInput_1_let = T.buffer_decl([379456], "int16") + with T.let(PaddedInput_1_let.data, T.address_of(global_workspace_3_buffer_var[0], dtype="handle")): + for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): + PaddedInput_1_let[i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1] = T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, placeholder_13[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864], T.int16(0), dtype="int16") + for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): + Conv2dOutput_1_let = T.buffer_decl([64], "int32") + with T.let(Conv2dOutput_1_let.data, T.address_of(global_workspace_3_buffer_var[7200000], dtype="handle")): + for ff_1 in T.serial(0, 64): + Conv2dOutput_1_let[ff_1] = 0 + for ry, rx, rc_1 in T.grid(3, 3, 64): + Conv2dOutput_1_let[ff_1] = Conv2dOutput_1_let[ff_1] + T.cast(PaddedInput_1_let[ax0_ax1_fused_ax2_fused_1 // 75 * 4928 + ry * 4928 + rx * 64 + ax0_ax1_fused_ax2_fused_1 % 75 * 64 + rc_1], "int32") * T.cast(placeholder_14[ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1], "int32") + for ax3_inner_2 in T.serial(0, 64): + T_cast_5[ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_1_let[ax3_inner_2] + placeholder_15[ax3_inner_2], 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16") @T.prim_func - def __tvm_main__(input: T.handle, global_workspace_0_var: T.handle, output: T.handle) -> None: + def __tvm_main__(input: T.handle, global_workspace_0_var: T.Ptr[T.uint8], output: T.handle) -> None: global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) - sid_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 5760000), dtype="handle") - sid_6_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 0), dtype="handle") - sid_7_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 6480000), dtype="handle") - sid_8_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 6480000), dtype="handle") + sid_2_let: T.Ptr[T.int8] = T.address_of(global_workspace_0_buffer_var[5760000], dtype="handle") + sid_6_let: T.Ptr[T.int8] = T.address_of(global_workspace_0_buffer_var[0], dtype="handle") + sid_7_let: T.Ptr[T.int8] = T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle") + sid_8_let: T.Ptr[T.int8] = T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle") T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2_let, global_workspace_0_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2_let, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8_let, global_workspace_0_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8_let, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7_let, global_workspace_0_buffer_var.data, dtype="int32")) @@ -509,14 +513,9 @@ def 
test_resnet_subgraph(): tir_mod_with_offsets_ref = ResnetStructurePlanned - # The TIR produced fails on roundtrip TVMScript testing. - # Therefore, indicates the TVMScript produced here and/or the parser - # is lacking functionality. Thus for these tests, uses a string - # version of the TVMScript for each function as a check instead. - for gv, func in tir_mod_with_offsets_ref.functions.items(): - assert str(tir_mod_with_offsets_ref[gv.name_hint].script()) == str( - tir_mod_with_offsets[gv.name_hint].script() - ) + for gv, ref_func in tir_mod_with_offsets_ref.functions.items(): + actual_func = tir_mod_with_offsets[gv.name_hint] + tvm.ir.assert_structural_equal(actual_func, ref_func) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_usmp_utils.py b/tests/python/unittest/test_tir_usmp_utils.py index 34e526ae5173..e6add3a5cfd3 100644 --- a/tests/python/unittest/test_tir_usmp_utils.py +++ b/tests/python/unittest/test_tir_usmp_utils.py @@ -31,53 +31,53 @@ class LinearStructure: def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) - placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) - T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): - T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True) + T_subtract_1[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)] = (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5[0]) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) - placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [64],
dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): for i2_7, i3_7 in T.grid(229, 3): - T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True) + PaddedInput_7[(((i0_i1_fused_7*687) + (i2_7*3)) + i3_7)] = T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): Conv2dOutput_7 = T.allocate([64], "int32", "global") for ff_3 in T.serial(0, 64): - T.store(Conv2dOutput_7, ff_3, 0, True) + Conv2dOutput_7[ff_3] = 0 for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): - T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True) + Conv2dOutput_7[ff_3] = (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))) for ax3_inner_7 in T.serial(0, 64): - T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + T_cast_21[((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8") @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) - placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body tensor_2 = T.allocate([200704], "uint8", "global") for ax0_ax1_fused_4 in T.serial(0, 56): for ax2_4 in T.serial(0, 56): for ax3_init in T.serial(0, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init)] = T.uint8(0) for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + 
T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)] = T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")) for ax0_ax1_fused_5 in T.serial(0, 56): for ax2_5, ax3_3 in T.grid(56, 64): - T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) + T_cast_7[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)] = T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16") @T.prim_func def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: @@ -88,9 +88,9 @@ def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: T.attr("default", "device_type", 1) sid_9 = T.allocate([301056], "int8", "global") sid_8 = T.allocate([802816], "int8", "global") - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9.data, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8.data, output, dtype="int32")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/unittest/test_transform_layout.py b/tests/python/unittest/test_transform_layout.py new file mode 100755 index 000000000000..5cac01dd7f7c --- /dev/null +++ b/tests/python/unittest/test_transform_layout.py @@ -0,0 +1,498 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
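The new test file that follows is built around two row-major index-mapping helpers, flatten_all_indices and unpack_flattened_indices. As a minimal standalone sketch of the mapping they implement (not part of the patch; it reuses the tests' (2, 3, 4) reordered_shape, and the numpy cross-check is purely illustrative):

import numpy as np

shape = (2, 3, 4)

def flatten(indices, shape):
    # Row-major flattening: ((i * 3) + j) * 4 + k for shape (2, 3, 4).
    out = 0
    for index, size in zip(indices, shape):
        out = out * size + index
    return out

def unflatten(flat, shape):
    # Inverse mapping: peel off the fastest-varying axis first.
    out = []
    for dim in reversed(shape):
        out.append(flat % dim)
        flat //= dim
    return out[::-1]

assert flatten((1, 2, 3), shape) == 1 * 12 + 2 * 4 + 3 == 23
assert unflatten(23, shape) == [1, 2, 3]
# Agrees with numpy's row-major ravel/unravel of the same index.
assert np.ravel_multi_index((1, 2, 3), shape) == 23
assert tuple(np.unravel_index(23, shape)) == (1, 2, 3)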
+ +import functools +import sys +import pytest + +import numpy as np + +import tvm +import tvm.testing +from tvm import te +from tvm.tir.stmt_functor import post_order_visit +from tvm.driver.build_module import schedule_to_module + +dtype = tvm.testing.parameter("int32") + + +def flatten_all_indices(preflatten_shape): + def mapping(*indices): + output = 0 + for index, size in zip(indices, preflatten_shape): + output = output * size + index + return [output] + + return mapping + + +def unpack_flattened_indices(preflatten_shape): + def mapping(i): + output = [] + for dim in reversed(preflatten_shape): + output.append(i % dim) + i //= dim + return output[::-1] + + return mapping + + +def traverse(s, op, callback): + visited = set() + + def _traverse(op): + if op in visited: + return + visited.add(op) + for tensor in op.input_tensors: + _traverse(tensor.op) + callback(op) + + _traverse(op) + + +class TestCompareAgainstExplicitReshape: + A_definition_style = tvm.testing.parameter( + "explicit_reshape", + "transform_layout", + ) + B_definition_style = tvm.testing.parameter( + "explicit_reshape", + "transform_layout", + ) + + reordered_shape = tvm.testing.parameter((2, 3, 4)) + + @tvm.testing.fixture + def n_items(self, reordered_shape): + return functools.reduce(lambda x, y: x * y, reordered_shape, 1) + + @tvm.testing.fixture + def fphysical_layout(self, reordered_shape): + return unpack_flattened_indices(reordered_shape) + + @tvm.testing.fixture + def fcompute(self, A_definition_style, B_definition_style, reordered_shape, n_items, dtype): + assert A_definition_style in ["explicit_reshape", "transform_layout"] + assert B_definition_style in ["explicit_reshape", "transform_layout"] + + def func(): + if A_definition_style == "explicit_reshape": + A_input = te.placeholder(shape=reordered_shape, name="A_input", dtype=dtype) + A = te.compute( + shape=(n_items,), + fcompute=lambda i: A_input[ + i // (reordered_shape[1] * reordered_shape[2]), + (i // reordered_shape[2]) % reordered_shape[1], + i % reordered_shape[2], + ], + name="A", + ) + + elif A_definition_style == "transform_layout": + A = te.placeholder(shape=(n_items,), name="A", dtype=dtype) + A_input = A + + B = te.compute(shape=A.shape, fcompute=lambda i: A[i], name="B") + + if B_definition_style == "explicit_reshape": + B_output = te.compute( + shape=reordered_shape, + fcompute=lambda i, j, k: B[ + i * reordered_shape[1] * reordered_shape[2] + j * reordered_shape[2] + k + ], + name="B_output", + ) + elif B_definition_style == "transform_layout": + B_output = B + + return A_input, B_output + + return func + + @tvm.testing.fixture + def fschedule(self, A_definition_style, B_definition_style, fphysical_layout): + def func(outs): + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def callback(op): + if (op.name == "A" and A_definition_style == "transform_layout") or ( + op.name == "B" and B_definition_style == "transform_layout" + ): + s[op].transform_layout(fphysical_layout) + + traverse(s, outs[0].op, callback) + return s + + return func + + @tvm.testing.parametrize_targets("llvm") + def test_external_reshape( + self, target, dev, fcompute, fschedule, n_items, reordered_shape, dtype + ): + A, B = fcompute() + s = fschedule(B) + + func = tvm.build(s, [A, B], target=target, name="copy_reshape") + + a_np = np.arange(n_items).reshape(reordered_shape).astype(dtype) + b_np = np.arange(n_items).reshape(reordered_shape).astype(dtype) + a = tvm.nd.array(a_np, dev) + b = 
tvm.nd.empty(b_np.shape, dtype=dtype, device=dev) + + func(a, b) + + tvm.testing.assert_allclose(b.numpy(), b_np) + + @tvm.testing.parametrize_targets("llvm") + def test_internal_reshape(self, target, dev, n_items, reordered_shape, dtype, fphysical_layout): + # The reshaping of the buffer gets flattened away in + # StorageFlatten. Therefore, this test checks the behavior by + # running only ApplyLayoutTransforms. + logical_shape = (n_items,) + A = te.placeholder(logical_shape, name="A", dtype=dtype) + B = te.compute(shape=logical_shape, fcompute=lambda i: A[i], name="B") + C = te.compute(shape=logical_shape, fcompute=lambda i: B[i], name="C") + + s = te.create_schedule(C.op) + s[B].transform_layout(fphysical_layout) + + mod = schedule_to_module(s, [A, C]) + body = mod["main"].body + + def walk_buffer_interactions(stmt, callback): + buffer_classes = [ + tvm.tir.BufferLoad, + tvm.tir.BufferStore, + tvm.tir.BufferRealize, + ] + + def inner(node): + if (type(node) in buffer_classes) and node.buffer.name == "B": + callback(node) + + post_order_visit(stmt, inner) + + # All references to the buffer are the same object. + def check_references(): + buffer_object = None + + def inner(node): + nonlocal buffer_object + if buffer_object is None: + buffer_object = node.buffer + else: + assert node.buffer.same_as(buffer_object) + + return inner + + # The buffer has the expected shape. + def check_shape(expected_shape): + def inner(node): + assert tuple(node.buffer.shape) == expected_shape + + return inner + + # Before the transform, the buffer should be in the logical shape. + walk_buffer_interactions(body, check_references()) + walk_buffer_interactions(body, check_shape(logical_shape)) + + mod = tvm.tir.transform.ApplyLayoutTransforms()(mod) + body = mod["main"].body + + # After the transform, the buffer should be in the physical shape. + walk_buffer_interactions(body, check_references()) + walk_buffer_interactions(body, check_shape(reordered_shape)) + + +class Test2DPhysicalLayout: + transform_A = tvm.testing.parameter( + by_dict={ + "2d_A": True, + "1d_A": False, + } + ) + transform_B = tvm.testing.parameter( + by_dict={ + "2d_B": True, + "1d_B": False, + } + ) + + @staticmethod + def extract_loop_vars(stmt): + output = [] + + def callback(node): + if isinstance(node, tvm.tir.For): + output.append(node.loop_var) + + post_order_visit(stmt, callback) + return output[::-1] + + def test_2d_physical(self, dtype, transform_A, transform_B): + logical_shape = (2, 3, 4) + A = te.placeholder(shape=logical_shape, dtype=dtype, name="A") + B = te.compute(shape=A.shape, fcompute=lambda i, j, k: A[i, j, k], name="B") + + s = te.create_schedule(B.op) + + if transform_A: + s[A].transform_layout(lambda i, j, k: [i, j, te.AXIS_SEPARATOR, k]) + + if transform_B: + s[B].transform_layout(lambda i, j, k: [i, j, te.AXIS_SEPARATOR, k]) + + # If the two buffers are accessed with the same indices, CSE + # will replace them with a Let binding. Since this makes it + # harder to test what the transformed indices are, the CSE pass + # is disabled for this test.
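+ # (For example, CSE could bind the shared index expression in a let + # and rewrite both accesses through it, e.g. B[v] = A[v] with + # v = i * 12 + j * 4 + k in the untransformed 1-d case, which would + # hide the per-buffer indices that the callback below checks.)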
+ with tvm.transform.PassContext(disabled_pass=["tir.CommonSubexprElimTIR"]): + mod = tvm.lower(s, [A, B]) + + i, j, k = self.extract_loop_vars(mod["main"].body) + indices_1d = [i * (logical_shape[1] * logical_shape[2]) + j * logical_shape[2] + k] + indices_2d = [i * logical_shape[1] + j, k] + + def callback(node): + if type(node) in [tvm.tir.BufferLoad, tvm.tir.BufferStore]: + name = node.buffer.name + if name == "A": + expected_indices = indices_2d if transform_A else indices_1d + elif name == "B": + expected_indices = indices_2d if transform_B else indices_1d + else: + raise RuntimeError(f"Unexpected buffer: {name}") + + tvm.ir.assert_structural_equal(expected_indices, node.indices) + + post_order_visit(mod["main"].body, callback) + + +class TestTransformedSchedules: + logical_shape = tvm.testing.parameter((4, 6, 40)) + + transform_names = [ + None, + "reverse", + "flatten_all", + "factor_last_by_4", + ] + + transform_A = tvm.testing.parameter(by_dict={f"A_{t}": t for t in transform_names}) + transform_B = tvm.testing.parameter( + by_dict={f"B_{t}": t for t in transform_names if t is not None} + ) + + after_transform = tvm.testing.parameter(None) + + def make_transform(self, logical_shape, transform_name): + if transform_name is None: + return lambda *indices: indices + elif transform_name == "reverse": + return lambda *indices: indices[::-1] + elif transform_name == "flatten_all": + return flatten_all_indices(logical_shape) + elif transform_name == "factor_last_by_4": + return lambda *indices, n: [*indices, n // 4, n % 4] + else: + raise NotImplementedError(f"Unknown transformation {transform_name}") + + def make_transformed_shape(self, logical_shape, transform_name): + if transform_name is None: + return logical_shape + elif transform_name == "reverse": + return logical_shape[::-1] + elif transform_name == "flatten_all": + num_elements = functools.reduce(lambda x, y: x * y, logical_shape, 1) + return [num_elements] + elif transform_name == "factor_last_by_4": + *indices, n = logical_shape + return [*indices, n // 4, 4] + else: + raise NotImplementedError(f"Unknown transformation {transform_name}") + + @tvm.testing.fixture + def expected_loop_order(self, logical_shape, transform_B, after_transform): + shape = self.make_transformed_shape(logical_shape, transform_B) + + if after_transform == "reorder": + shape = shape[::-1] + + elif after_transform == "split": + shape = [ + *shape[:-1], + 2, + shape[-1] // 2, + ] + + elif after_transform == "fuse": + fused_size = shape[0] if transform_B == "flatten_all" else shape[0] * shape[1] + shape = [fused_size, *shape[2:]] + + return shape + + @tvm.testing.fixture + def schedule(self, logical_shape, dtype, transform_A, transform_B, after_transform): + A = te.placeholder(shape=logical_shape, dtype=dtype, name="A") + B = te.compute(shape=A.shape, fcompute=lambda i, j, k: A[i, j, k], name="B") + + s = te.create_schedule(B.op) + + if transform_A: + s[A].transform_layout(self.make_transform(logical_shape, transform_A)) + + iter_vars = s[B].transform_layout(self.make_transform(logical_shape, transform_B)) + iter_vars = list(iter_vars) + + if after_transform == "reorder": + s[B].reorder(*iter_vars[::-1]) + + elif after_transform == "split": + s[B].split(iter_vars[-1], nparts=2) + + elif after_transform == "fuse": + to_fuse = iter_vars[:2] + s[B].fuse(*iter_vars[:2]) + + return { + "schedule": s, + "tensors": [A, B], + "iter_vars": iter_vars, + } + + def compare_tir_loop_order(self, stmt, expected_loop_order): + def collect_loops(node): + output = [] + + 
def callback(node): + if isinstance(node, tvm.tir.For): + output.append(node) + + post_order_visit(node, callback) + return output[::-1] + + loops = collect_loops(stmt) + loop_order = [loop.extent for loop in loops] + + np.testing.assert_array_equal(loop_order, expected_loop_order) + + def test_tir_loop_order(self, schedule, expected_loop_order): + func = tvm.lower(schedule["schedule"], schedule["tensors"])["main"] + self.compare_tir_loop_order(func.body, expected_loop_order) + + def test_te_loop_order(self, schedule, expected_loop_order): + s = schedule["schedule"] + A, B = schedule["tensors"] + iter_vars = schedule["iter_vars"] + + # No reduction axis, so all leaf_iter_vars are over the data + # array, and should have the new iteration variables. + extents = [int(iter_var.dom.extent) for iter_var in s[B].leaf_iter_vars] + np.testing.assert_array_equal(extents, expected_loop_order) + + # layout_transform should return the new iteration variables. + extents = [int(iter_var.dom.extent) for iter_var in iter_vars] + np.testing.assert_array_equal(extents, expected_loop_order) + + @pytest.mark.parametrize("after_transform", ["reorder", "split", "fuse"]) + def test_use_transformed_axes( + self, schedule, expected_loop_order, transform_A, transform_B, after_transform + ): + s = schedule["schedule"] + A, B = schedule["tensors"] + + func = tvm.lower(s, [A, B])["main"] + self.compare_tir_loop_order(func.body, expected_loop_order) + + +class TestTransformCache: + A_size = tvm.testing.parameter(16) + + transform_A = tvm.testing.parameter(by_dict={"transformA": True, "": False}) + transform_B = tvm.testing.parameter(by_dict={"transformB": True, "": False}) + cache_A = tvm.testing.parameter(by_dict={"cacheA": True, "": False}) + cache_B = tvm.testing.parameter(by_dict={"cacheB": True, "": False}) + + @tvm.testing.fixture + def schedule_args(self, target, A_size, transform_A, transform_B, cache_A, cache_B, dtype): + A = te.placeholder(shape=[A_size], dtype=dtype, name="A") + B = te.compute(A.shape, lambda i: A[i], name="B") + s = te.create_schedule(B.op) + + requires_thread_bind = "gpu" in tvm.target.Target(target).keys + thread_x = te.thread_axis("threadIdx.x") + thread_y = te.thread_axis("threadIdx.y") + thread_z = te.thread_axis("threadIdx.z") + + if cache_A: + AA = s.cache_read(A, "shared", [B]) + if requires_thread_bind: + s[AA].bind(AA.op.axis[0], thread_x) + + if cache_B: + BB = s.cache_write(B, "shared") + if requires_thread_bind: + s[BB].bind(BB.op.axis[0], thread_y) + + if transform_A: + A_axis = s[A].transform_layout(lambda i: [i // 4, i % 4]) + + if transform_B: + B_axis = s[B].transform_layout(lambda i: [i // 4, i % 4]) + else: + B_axis = B.op.axis + + if requires_thread_bind: + s[B].bind(B_axis[0], thread_z) + + return [s, [A, B]] + + @tvm.testing.fixture + def ref_data(self, A_size, dtype, transform_A, transform_B): + a_np = (100 * np.random.uniform(size=A_size)).astype(dtype) + b_np = a_np + + if transform_A: + a_np = a_np.reshape((-1, 4)) + + if transform_B: + b_np = b_np.reshape((-1, 4)) + + return a_np, b_np + + def test_lower(self, schedule_args): + tvm.lower(*schedule_args) + + def test_execute(self, target, dev, schedule_args, ref_data, dtype): + func = tvm.build(*schedule_args, target=target) + + a_np, b_np = ref_data + a = tvm.nd.array(a_np, dev) + b = tvm.nd.empty(b_np.shape, dtype=dtype, device=dev) + + func(a, b) + + if "int" in dtype: + np.testing.assert_equal(b.numpy(), b_np) + else: + tvm.testing.assert_allclose(b.numpy(), b_np) + + +if __name__ == "__main__": + 
sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/unittest/test_tvmscript_complete.py b/tests/python/unittest/test_tvmscript_complete.py index 691e6cd9bbb6..429f54809929 100644 --- a/tests/python/unittest/test_tvmscript_complete.py +++ b/tests/python/unittest/test_tvmscript_complete.py @@ -319,12 +319,6 @@ def test_complete_alloc_buffer(): tvm.ir.assert_structural_equal(alloc_buffer_func, expect_alloc_buffer_func) -@T.prim_func -def load_var() -> None: - d = T.var("float32") - d[1] = d[1] - - if __name__ == "__main__": test_complete_matmul() test_complete_matmul_original() diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 19dc81290e16..462142e2e534 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -332,8 +332,7 @@ def opaque_access_during_complete(a: T.handle) -> None: # error A = T.match_buffer(a, (16, 16), "float32") for i, j in T.grid(16, 16): with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - T.evaluate(T.load("float32", A.data, vi * 16 + vj)) + T.evaluate(T.call_extern("dummy_extern_function", A.data, dtype="int32")) def test_opaque_access_during_complete(): @@ -415,7 +414,7 @@ def intrin_except_unassign(a: T.handle) -> None: def intrin_except_assign(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") - A[0, 0] = T.load(A, A, A) # error + A[0, 0] = A[A] # error def test_tvm_exception_catch(): diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index e3a70bb0c7ad..36eeac0d85b8 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -19,463 +19,374 @@ import pytest import tvm +import tvm.testing from tvm import tir from tvm.script import tir as T import numpy as np -@tvm.script.ir_module -class Module1: - @T.prim_func - def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": "mmult", "tir.noalias": True}) - # buffer definition - C_global = T.buffer_decl([1024, 1024], elem_offset=0, align=128, offset_factor=1) - packedB = T.buffer_decl([32, 1024, 32], elem_offset=0, align=128, offset_factor=1) - A_1 = T.match_buffer(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1) - B_1 = T.match_buffer(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) - C_1 = T.match_buffer(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) - # body - T.realize(packedB[0:32, 0:1024, 0:32], "") - for x in T.parallel(0, 32): - for y in T.serial(0, 1024): - for z in T.vectorized(0, 32): - packedB[x, y, z] = B_1[y, ((x * 32) + z)] - T.realize(C_1[0:1024, 0:1024], "") - for x_outer in T.parallel(0, 32): - for y_outer in T.serial(0, 32): - T.realize( - C_global[ - (x_outer * 32) : ((x_outer * 32) + 32), - (y_outer * 32) : ((y_outer * 32) + 32), - ], - "global", - ) - for x_c_init in T.serial(0, 32): - for y_c_init in T.vectorized(0, 32): +def opt_gemm_normalize(): + @tvm.script.ir_module + class Module: + @T.prim_func + def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "mmult", "tir.noalias": True}) + # buffer definition + C_global = T.buffer_decl([1024, 1024], elem_offset=0, align=128, offset_factor=1) + packedB = T.buffer_decl([32, 1024, 32], elem_offset=0, align=128, offset_factor=1) + A_1 = T.match_buffer(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + B_1 
= T.match_buffer(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + C_1 = T.match_buffer(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + # body + T.realize(packedB[0:32, 0:1024, 0:32], "") + for x in T.parallel(0, 32): + for y in T.serial(0, 1024): + for z in T.vectorized(0, 32): + packedB[x, y, z] = B_1[y, ((x * 32) + z)] + T.realize(C_1[0:1024, 0:1024], "") + for x_outer in T.parallel(0, 32): + for y_outer in T.serial(0, 32): + T.realize( C_global[ - (x_c_init + (x_outer * 32)), (y_c_init + (y_outer * 32)) - ] = T.float32(0) - for k_outer in T.serial(0, 256): - for x_c in T.serial(0, 32): - for k_inner in T.unroll(0, 4): - for y_c in T.vectorized(0, 32): - C_global[(x_c + (x_outer * 32)), (y_c + (y_outer * 32))] = C_global[ - (x_c + (x_outer * 32)), (y_c + (y_outer * 32)) - ] + ( - A_1[(x_c + (x_outer * 32)), (k_inner + (k_outer * 4))] - * packedB[ - T.floordiv((y_c + (y_outer * 32)), 32), - (k_inner + (k_outer * 4)), - T.floormod((y_c + (y_outer * 32)), 32), - ] - ) - for x_inner in T.serial(0, 32): - for y_inner in T.serial(0, 32): - C_1[(x_inner + (x_outer * 32)), (y_inner + (y_outer * 32))] = C_global[ - (x_inner + (x_outer * 32)), (y_inner + (y_outer * 32)) - ] - - -def test_opt_gemm_normalize(): - mod = Module1 - rt_mod = tvm.script.from_source(mod.script(show_meta=True)) - tvm.ir.assert_structural_equal(mod, rt_mod, True) - - -@tvm.script.ir_module -class Module2: - @T.prim_func - def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": "mmult", "tir.noalias": True}) - A_1 = T.match_buffer(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1) - B_1 = T.match_buffer(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) - C_1 = T.match_buffer(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) - # body - packedB = T.allocate([32768], "float32x32", "global") - for x in T.parallel(0, 32): - for y in T.serial(0, 1024): - T.store( - packedB, - T.ramp(((x * 32768) + (y * 32)), 1, 32), - T.load( - "float32x32", - B_1.data, - T.ramp(((y * 1024) + (x * 32)), 1, 32), - T.broadcast(True, 32), - ), - T.broadcast(True, 32), - ) - for x_outer in T.parallel(0, 32): - C_global = T.allocate([1024], "float32", "global") - for y_outer in T.serial(0, 32): - for x_c_init in T.serial(0, 32): - T.store( - C_global, - T.ramp((x_c_init * 32), 1, 32), - T.broadcast(T.float32(0), 32), - T.broadcast(True, 32), + (x_outer * 32) : ((x_outer * 32) + 32), + (y_outer * 32) : ((y_outer * 32) + 32), + ], + "global", ) - for k_outer in T.serial(0, 256): - for x_c in T.serial(0, 32): - T.store( - C_global, - T.ramp((x_c * 32), 1, 32), - ( - T.load( - "float32x32", - C_global, - T.ramp((x_c * 32), 1, 32), - T.broadcast(True, 32), - ) - + ( - T.broadcast( - T.load( - "float32", - A_1.data, - (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)), - ), - 32, - ) - * T.load( - "float32x32", - packedB, - T.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32), - T.broadcast(True, 32), - ) - ) - ), - T.broadcast(True, 32), - ) - T.store( - C_global, - T.ramp((x_c * 32), 1, 32), - ( - T.load( - "float32x32", - C_global, - T.ramp((x_c * 32), 1, 32), - T.broadcast(True, 32), - ) - + ( - T.broadcast( - T.load( - "float32", - A_1.data, - ( - (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) - + 1 - ), - ), - 32, - ) - * T.load( - "float32x32", - packedB, - T.ramp((((y_outer * 32768) + (k_outer * 128)) + 32), 1, 32), - T.broadcast(True, 32), + for x_c_init in T.serial(0, 32): + for y_c_init in T.vectorized(0, 32): + C_global[ + 
(x_c_init + (x_outer * 32)), (y_c_init + (y_outer * 32)) + ] = T.float32(0) + for k_outer in T.serial(0, 256): + for x_c in T.serial(0, 32): + for k_inner in T.unroll(0, 4): + for y_c in T.vectorized(0, 32): + C_global[ + (x_c + (x_outer * 32)), (y_c + (y_outer * 32)) + ] = C_global[(x_c + (x_outer * 32)), (y_c + (y_outer * 32))] + ( + A_1[(x_c + (x_outer * 32)), (k_inner + (k_outer * 4))] + * packedB[ + T.floordiv((y_c + (y_outer * 32)), 32), + (k_inner + (k_outer * 4)), + T.floormod((y_c + (y_outer * 32)), 32), + ] ) + for x_inner in T.serial(0, 32): + for y_inner in T.serial(0, 32): + C_1[(x_inner + (x_outer * 32)), (y_inner + (y_outer * 32))] = C_global[ + (x_inner + (x_outer * 32)), (y_inner + (y_outer * 32)) + ] + + return Module + + +def opt_gemm_lower(): + @tvm.script.ir_module + class Module: + @T.prim_func + def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "mmult", "tir.noalias": True}) + A_1 = T.match_buffer(A, [1024 * 1024], elem_offset=0, align=128, offset_factor=1) + B_1 = T.match_buffer(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + C_1 = T.match_buffer(C, [1024 * 1024], elem_offset=0, align=128, offset_factor=1) + # body + packedB = T.allocate([32768], "float32", "global") + for x in T.parallel(0, 32): + for y in T.serial(0, 1024): + packedB[T.ramp(((x * 32768) + (y * 32)), 1, 32)] = B_1[y, T.ramp(x * 32, 1, 32)] + for x_outer in T.parallel(0, 32): + C_global = T.allocate([1024], "float32", "global") + for y_outer in T.serial(0, 32): + for x_c_init in T.serial(0, 32): + C_global[T.ramp((x_c_init * 32), 1, 32)] = T.broadcast(T.float32(0), 32) + for k_outer in T.serial(0, 256): + for x_c in T.serial(0, 32): + C_global[T.ramp((x_c * 32), 1, 32)] = C_global[ + T.ramp((x_c * 32), 1, 32) + ] + ( + T.broadcast( + A_1[ + (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)), + ], + 32, ) - ), - T.broadcast(True, 32), - ) - T.store( - C_global, - T.ramp((x_c * 32), 1, 32), - ( - T.load( - "float32x32", - C_global, - T.ramp((x_c * 32), 1, 32), - T.broadcast(True, 32), + * packedB[T.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32)] + ) + C_global[T.ramp((x_c * 32), 1, 32)] = C_global[ + T.ramp((x_c * 32), 1, 32) + ] + ( + T.broadcast( + A_1[ + ((((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) + 1), + ], + 32, ) - + ( - T.broadcast( - T.load( - "float32", - A_1.data, - ( - (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) - + 2 - ), - ), - 32, - ) - * T.load( - "float32x32", - packedB, - T.ramp((((y_outer * 32768) + (k_outer * 128)) + 64), 1, 32), - T.broadcast(True, 32), - ) + * packedB[ + T.ramp((((y_outer * 32768) + (k_outer * 128)) + 32), 1, 32) + ] + ) + C_global[T.ramp((x_c * 32), 1, 32)] = C_global[ + T.ramp((x_c * 32), 1, 32) + ] + ( + T.broadcast( + A_1[ + ((((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) + 2), + ], + 32, ) - ), - T.broadcast(True, 32), - ) - T.store( - C_global, - T.ramp((x_c * 32), 1, 32), - ( - T.load( - "float32x32", - C_global, - T.ramp((x_c * 32), 1, 32), - T.broadcast(True, 32), + * packedB[ + T.ramp((((y_outer * 32768) + (k_outer * 128)) + 64), 1, 32) + ] + ) + C_global[T.ramp((x_c * 32), 1, 32)] = C_global[ + T.ramp((x_c * 32), 1, 32) + ] + ( + T.broadcast( + A_1[ + ((((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) + 3), + ], + 32, ) - + ( - T.broadcast( - T.load( - "float32", - A_1.data, - ( - (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) - + 3 - ), - ), - 32, - ) - * T.load( - "float32x32", - packedB, - T.ramp((((y_outer * 32768) + (k_outer * 128)) 
+ 96), 1, 32), - T.broadcast(True, 32), - ) + * packedB[ + T.ramp((((y_outer * 32768) + (k_outer * 128)) + 96), 1, 32) + ] + ) + for x_inner in T.serial(0, 32): + for y_inner in T.serial(0, 32): + C_1[ + ( + (((x_outer * 32768) + (x_inner * 1024)) + (y_outer * 32)) + + y_inner ) - ), - T.broadcast(True, 32), - ) - for x_inner in T.serial(0, 32): - for y_inner in T.serial(0, 32): - C_1.data[ - ((((x_outer * 32768) + (x_inner * 1024)) + (y_outer * 32)) + y_inner) - ] = T.load("float32", C_global, ((x_inner * 32) + y_inner)) - - -def test_opt_gemm_lower(): - mod = Module2 - rt_mod = tvm.script.from_source(mod.script(show_meta=True)) - tvm.ir.assert_structural_equal(mod, rt_mod, True) - - -@tvm.script.ir_module -class Module3: - @T.prim_func - def mmult( - args: T.handle, - arg_type_ids: T.handle, - num_args: T.int32, - out_ret_value: T.handle, - out_ret_tcode: T.handle, - ) -> T.int32: - # function attr dict - T.func_attr( - { - "tir.noalias": True, - "global_symbol": "mmult", - "tir.is_entry_func": True, - "calling_conv": 1, - } - ) - # var definition - C_global = T.buffer_var("float32", "global") - packedB = T.buffer_var("float32", "global") - # body - assert num_args == 3, "mmult: num_args should be 3" - arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle") - arg0_code: T.int32 = T.load("int32", arg_type_ids, 0) - arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle") - arg1_code: T.int32 = T.load("int32", arg_type_ids, 1) - arg2: T.handle = T.tvm_struct_get(args, 2, 12, dtype="handle") - arg2_code: T.int32 = T.load("int32", arg_type_ids, 2) - A: T.handle = T.tvm_struct_get(arg0, 0, 1, dtype="handle") - T.attr(A, "storage_alignment", 128) - arg0_shape: T.handle = T.tvm_struct_get(arg0, 0, 2, dtype="handle") - arg0_strides: T.handle = T.tvm_struct_get(arg0, 0, 3, dtype="handle") - dev_id: T.int32 = T.tvm_struct_get(arg0, 0, 9, dtype="int32") - B: T.handle = T.tvm_struct_get(arg1, 0, 1, dtype="handle") - T.attr(B, "storage_alignment", 128) - arg1_shape: T.handle = T.tvm_struct_get(arg1, 0, 2, dtype="handle") - arg1_strides: T.handle = T.tvm_struct_get(arg1, 0, 3, dtype="handle") - C: T.handle = T.tvm_struct_get(arg2, 0, 1, dtype="handle") - T.attr(C, "storage_alignment", 128) - arg2_shape: T.handle = T.tvm_struct_get(arg2, 0, 2, dtype="handle") - arg2_strides: T.handle = T.tvm_struct_get(arg2, 0, 3, dtype="handle") - assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or ( - arg0_code == 4 - ), "mmult: Expect arg[0] to be pointer" - assert (((arg1_code == 3) or (arg1_code == 13)) or (arg1_code == 7)) or ( - arg1_code == 4 - ), "mmult: Expect arg[1] to be pointer" - assert (((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or ( - arg2_code == 4 - ), "mmult: Expect arg[2] to be pointer" - assert 2 == T.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 2" - assert 2 == T.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 2" - assert ( - (T.tvm_struct_get(arg0, 0, 5, dtype="uint8") == T.uint8(2)) - and (T.tvm_struct_get(arg0, 0, 6, dtype="uint8") == T.uint8(32)) - ) and ( - T.tvm_struct_get(arg0, 0, 7, dtype="uint16") == T.uint16(1) - ), "arg0.dtype is expected to be float32" - assert 1024 == T.cast( - T.load("int64", arg0_shape, 0), "int32" - ), "Argument arg0.shape[0] has an unsatisfied constraint" - assert 1024 == T.cast( - T.load("int64", arg0_shape, 1), "int32" - ), "Argument arg0.shape[1] has an unsatisfied constraint" - if not (T.isnullptr(arg0_strides, dtype="bool")): - assert (1 == 
T.cast(T.load("int64", arg0_strides, 1), "int32")) and ( - 1024 == T.cast(T.load("int64", arg0_strides, 0), "int32") - ), "arg0.strides: expected to be compact array" - T.evaluate(0) - assert T.uint64(0) == T.tvm_struct_get( - arg0, 0, 8, dtype="uint64" - ), "Argument arg0.byte_offset has an unsatisfied constraint" - assert 1 == T.tvm_struct_get( - arg0, 0, 10, dtype="int32" - ), "Argument arg0.device_type has an unsatisfied constraint" - assert 2 == T.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 2" - assert 2 == T.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 2" - assert ( - (T.tvm_struct_get(arg1, 0, 5, dtype="uint8") == T.uint8(2)) - and (T.tvm_struct_get(arg1, 0, 6, dtype="uint8") == T.uint8(32)) - ) and ( - T.tvm_struct_get(arg1, 0, 7, dtype="uint16") == T.uint16(1) - ), "arg1.dtype is expected to be float32" - assert 1024 == T.cast( - T.load("int64", arg1_shape, 0), "int32" - ), "Argument arg1.shape[0] has an unsatisfied constraint" - assert 1024 == T.cast( - T.load("int64", arg1_shape, 1), "int32" - ), "Argument arg1.shape[1] has an unsatisfied constraint" - if not (T.isnullptr(arg1_strides, dtype="bool")): - assert (1 == T.cast(T.load("int64", arg1_strides, 1), "int32")) and ( - 1024 == T.cast(T.load("int64", arg1_strides, 0), "int32") - ), "arg1.strides: expected to be compact array" - T.evaluate(0) - assert T.uint64(0) == T.tvm_struct_get( - arg1, 0, 8, dtype="uint64" - ), "Argument arg1.byte_offset has an unsatisfied constraint" - assert 1 == T.tvm_struct_get( - arg1, 0, 10, dtype="int32" - ), "Argument arg1.device_type has an unsatisfied constraint" - assert dev_id == T.tvm_struct_get( - arg1, 0, 9, dtype="int32" - ), "Argument arg1.device_id has an unsatisfied constraint" - assert 2 == T.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 2" - assert 2 == T.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 2" - assert ( - (T.tvm_struct_get(arg2, 0, 5, dtype="uint8") == T.uint8(2)) - and (T.tvm_struct_get(arg2, 0, 6, dtype="uint8") == T.uint8(32)) - ) and ( - T.tvm_struct_get(arg2, 0, 7, dtype="uint16") == T.uint16(1) - ), "arg2.dtype is expected to be float32" - assert 1024 == T.cast( - T.load("int64", arg2_shape, 0), "int32" - ), "Argument arg2.shape[0] has an unsatisfied constraint" - assert 1024 == T.cast( - T.load("int64", arg2_shape, 1), "int32" - ), "Argument arg2.shape[1] has an unsatisfied constraint" - if not (T.isnullptr(arg2_strides, dtype="bool")): - assert (1 == T.cast(T.load("int64", arg2_strides, 1), "int32")) and ( - 1024 == T.cast(T.load("int64", arg2_strides, 0), "int32") - ), "arg2.strides: expected to be compact array" - T.evaluate(0) - assert T.uint64(0) == T.tvm_struct_get( - arg2, 0, 8, dtype="uint64" - ), "Argument arg2.byte_offset has an unsatisfied constraint" - assert 1 == T.tvm_struct_get( - arg2, 0, 10, dtype="int32" - ), "Argument arg2.device_type has an unsatisfied constraint" - assert dev_id == T.tvm_struct_get( - arg2, 0, 9, dtype="int32" - ), "Argument arg2.device_id has an unsatisfied constraint" - T.attr(0, "compute_scope", "mmult_compute_") - T.attr(packedB, "storage_scope", "global") - T.attr(packedB, "storage_alignment", 128) - with T.let( - packedB, - T.TVMBackendAllocWorkspace(1, dev_id, T.uint64(4194304), 2, 32, dtype="handle"), - ): - if T.isnullptr(packedB, dtype="bool"): - T.evaluate(T.tvm_throw_last_error(dtype="int32")) - for x in T.parallel(0, 32): - for y in T.serial(0, 1024): - T.store( - packedB, - T.ramp(((x * 32768) + 
(y * 32)), 1, 32), - T.load( - "float32x32", - B, - T.ramp(((y * 1024) + (x * 32)), 1, 32), - T.broadcast(True, 32), + ] = C_global[((x_inner * 32) + y_inner)] + + return Module + + +def opt_gemm_mod_host(): + @tvm.script.ir_module + class Module: + @T.prim_func + def mmult( + args: T.handle, + arg_type_ids: T.handle, + num_args: T.int32, + out_ret_value: T.handle, + out_ret_tcode: T.handle, + ) -> T.int32: + # function attr dict + T.func_attr( + { + "tir.noalias": True, + "global_symbol": "mmult", + "tir.is_entry_func": True, + "calling_conv": 1, + } + ) + # buffer definition + buf_type_ids = T.match_buffer(arg_type_ids, [3], dtype="int32") + + packedB = T.buffer_decl([32768], dtype="float32") + C_global = T.buffer_decl([1024], dtype="float32") + # var definition + # C_global = T.buffer_var("float32", "global") + # packedB = T.buffer_var("float32", "global") + # body + assert num_args == 3, "mmult: num_args should be 3" + arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle") + arg0_code: T.int32 = buf_type_ids[0] + arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle") + arg1_code: T.int32 = buf_type_ids[1] + arg2: T.handle = T.tvm_struct_get(args, 2, 12, dtype="handle") + arg2_code: T.int32 = buf_type_ids[2] + + A_data: T.Ptr[T.int32] = T.tvm_struct_get(arg0, 0, 1, dtype="handle") + T.attr(A_data, "storage_alignment", 128) + A: T.Buffer = T.buffer_decl([1024 * 1024], dtype="int32", data=A_data) + buf0_shape_data: T.Ptr[T.int32] = T.tvm_struct_get(arg0, 0, 2, dtype="handle") + buf0_shape: T.Buffer = T.buffer_decl([2], dtype="int32", data=buf0_shape_data) + buf0_strides_data: T.Ptr[T.int32] = T.tvm_struct_get(arg0, 0, 3, dtype="handle") + buf0_strides: T.Buffer = T.buffer_decl([2], dtype="int32", data=buf0_strides_data) + + dev_id: T.int32 = T.tvm_struct_get(arg0, 0, 9, dtype="int32") + + B_data: T.Ptr[T.int32] = T.tvm_struct_get(arg1, 0, 1, dtype="handle") + T.attr(B_data, "storage_alignment", 128) + B: T.Buffer = T.buffer_decl([1024 * 1024], dtype="int32", data=B_data) + buf1_shape_data: T.Ptr[T.int32] = T.tvm_struct_get(arg1, 0, 2, dtype="handle") + buf1_shape: T.Buffer = T.buffer_decl([2], dtype="int32", data=buf1_shape_data) + buf1_strides_data: T.Ptr[T.int32] = T.tvm_struct_get(arg1, 0, 3, dtype="handle") + buf1_strides: T.Buffer = T.buffer_decl([2], dtype="int32", data=buf1_strides_data) + + C_data: T.Ptr[T.int32] = T.tvm_struct_get(arg2, 0, 1, dtype="handle") + T.attr(C_data, "storage_alignment", 128) + C: T.Buffer = T.buffer_decl([1024 * 1024], dtype="int32", data=C_data) + buf2_shape_data: T.Ptr[T.int32] = T.tvm_struct_get(arg2, 0, 2, dtype="handle") + buf2_shape: T.Buffer = T.buffer_decl([2], dtype="int32", data=buf2_shape_data) + buf2_strides_data: T.Ptr[T.int32] = T.tvm_struct_get(arg2, 0, 3, dtype="handle") + buf2_strides: T.Buffer = T.buffer_decl([2], dtype="int32", data=buf2_strides_data) + + assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or ( + arg0_code == 4 + ), "mmult: Expect arg[0] to be pointer" + assert (((arg1_code == 3) or (arg1_code == 13)) or (arg1_code == 7)) or ( + arg1_code == 4 + ), "mmult: Expect arg[1] to be pointer" + assert (((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or ( + arg2_code == 4 + ), "mmult: Expect arg[2] to be pointer" + assert 2 == T.tvm_struct_get( + arg0, 0, 4, dtype="int32" + ), "arg0.ndim is expected to equal 2" + assert 2 == T.tvm_struct_get( + arg0, 0, 4, dtype="int32" + ), "arg0.ndim is expected to equal 2" + assert ( + (T.tvm_struct_get(arg0, 0, 5, dtype="uint8") == T.uint8(2)) + 
and (T.tvm_struct_get(arg0, 0, 6, dtype="uint8") == T.uint8(32)) + ) and ( + T.tvm_struct_get(arg0, 0, 7, dtype="uint16") == T.uint16(1) + ), "arg0.dtype is expected to be float32" + assert 1024 == T.cast( + buf0_shape[0], "int32" + ), "Argument arg0.shape[0] has an unsatisfied constraint" + assert 1024 == T.cast( + buf0_shape[1], "int32" + ), "Argument arg0.shape[1] has an unsatisfied constraint" + if not (T.isnullptr(buf0_strides.data, dtype="bool")): + assert (1 == T.cast(buf0_strides[1], "int32")) and ( + 1024 == T.cast(buf0_strides[0], "int32") + ), "arg0.strides: expected to be compact array" + T.evaluate(0) + assert T.uint64(0) == T.tvm_struct_get( + arg0, 0, 8, dtype="uint64" + ), "Argument arg0.byte_offset has an unsatisfied constraint" + assert 1 == T.tvm_struct_get( + arg0, 0, 10, dtype="int32" + ), "Argument arg0.device_type has an unsatisfied constraint" + assert 2 == T.tvm_struct_get( + arg1, 0, 4, dtype="int32" + ), "arg1.ndim is expected to equal 2" + assert 2 == T.tvm_struct_get( + arg1, 0, 4, dtype="int32" + ), "arg1.ndim is expected to equal 2" + assert ( + (T.tvm_struct_get(arg1, 0, 5, dtype="uint8") == T.uint8(2)) + and (T.tvm_struct_get(arg1, 0, 6, dtype="uint8") == T.uint8(32)) + ) and ( + T.tvm_struct_get(arg1, 0, 7, dtype="uint16") == T.uint16(1) + ), "arg1.dtype is expected to be float32" + assert 1024 == T.cast( + buf1_shape[0], "int32" + ), "Argument arg1.shape[0] has an unsatisfied constraint" + assert 1024 == T.cast( + buf1_shape[1], "int32" + ), "Argument arg1.shape[1] has an unsatisfied constraint" + if not (T.isnullptr(buf1_strides.data, dtype="bool")): + assert (1 == T.cast(buf1_strides[1], "int32")) and ( + 1024 == T.cast(buf1_strides[0], "int32") + ), "arg1.strides: expected to be compact array" + T.evaluate(0) + assert T.uint64(0) == T.tvm_struct_get( + arg1, 0, 8, dtype="uint64" + ), "Argument arg1.byte_offset has an unsatisfied constraint" + assert 1 == T.tvm_struct_get( + arg1, 0, 10, dtype="int32" + ), "Argument arg1.device_type has an unsatisfied constraint" + assert dev_id == T.tvm_struct_get( + arg1, 0, 9, dtype="int32" + ), "Argument arg1.device_id has an unsatisfied constraint" + assert 2 == T.tvm_struct_get( + arg2, 0, 4, dtype="int32" + ), "arg2.ndim is expected to equal 2" + assert 2 == T.tvm_struct_get( + arg2, 0, 4, dtype="int32" + ), "arg2.ndim is expected to equal 2" + assert ( + (T.tvm_struct_get(arg2, 0, 5, dtype="uint8") == T.uint8(2)) + and (T.tvm_struct_get(arg2, 0, 6, dtype="uint8") == T.uint8(32)) + ) and ( + T.tvm_struct_get(arg2, 0, 7, dtype="uint16") == T.uint16(1) + ), "arg2.dtype is expected to be float32" + assert 1024 == T.cast( + buf2_shape[0], "int32" + ), "Argument arg2.shape[0] has an unsatisfied constraint" + assert 1024 == T.cast( + buf2_shape[1], "int32" + ), "Argument arg2.shape[1] has an unsatisfied constraint" + if not (T.isnullptr(buf2_strides.data, dtype="bool")): + assert (1 == T.cast(buf2_strides[1], "int32")) and ( + 1024 == T.cast(buf2_strides[0], "int32") + ), "arg2.strides: expected to be compact array" + T.evaluate(0) + assert T.uint64(0) == T.tvm_struct_get( + arg2, 0, 8, dtype="uint64" + ), "Argument arg2.byte_offset has an unsatisfied constraint" + assert 1 == T.tvm_struct_get( + arg2, 0, 10, dtype="int32" + ), "Argument arg2.device_type has an unsatisfied constraint" + assert dev_id == T.tvm_struct_get( + arg2, 0, 9, dtype="int32" + ), "Argument arg2.device_id has an unsatisfied constraint" + T.attr(0, "compute_scope", "mmult_compute_") + T.attr(packedB.data, "storage_scope", "global") + 
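+                # Editor's note: the hunks above replace raw-pointer reads such as
+                # T.load("int64", arg0_shape, 0) with subscripts on aliased views
+                # declared via T.buffer_decl(..., data=ptr). A minimal sketch of
+                # the idiom, with illustrative names (ptr, view, x) that are not
+                # part of this patch:
+                #
+                #     ptr: T.Ptr[T.int32] = T.tvm_struct_get(arg0, 0, 2, dtype="handle")
+                #     view = T.buffer_decl([2], dtype="int32", data=ptr)
+                #     x: T.int32 = view[0]  # was: T.load("int32", ptr, 0)
+                #
+                # The view carries its dtype, so element reads no longer pass a
+                # per-load type string. Note also that the workspace binding below
+                # now targets the backing pointer (packedB.data) rather than the
+                # bare handle, since packedB is a buffer view here.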
T.attr(packedB.data, "storage_alignment", 128) + with T.let( + packedB.data, + T.TVMBackendAllocWorkspace(1, dev_id, T.uint64(4194304), 2, 32, dtype="handle"), + ): + if T.isnullptr(packedB.data, dtype="bool"): + T.evaluate(T.tvm_throw_last_error(dtype="int32")) + for x in T.parallel(0, 32): + for y in T.serial(0, 1024): + packedB[T.ramp(((x * 32768) + (y * 32)), 1, 32)] = B[ + T.ramp(((y * 1024) + (x * 32)), 1, 32) + ] + for x_outer in T.parallel(0, 32): + T.attr(C_global.data, "storage_scope", "global") + T.attr(C_global.data, "storage_alignment", 128) + with T.let( + C_global.data, + T.TVMBackendAllocWorkspace( + 1, dev_id, T.uint64(4096), 2, 32, dtype="handle" ), - T.broadcast(True, 32), - ) - for x_outer in T.parallel(0, 32): - T.attr(C_global, "storage_scope", "global") - T.attr(C_global, "storage_alignment", 128) - with T.let( - C_global, - T.TVMBackendAllocWorkspace(1, dev_id, T.uint64(4096), 2, 32, dtype="handle"), - ): - if T.isnullptr(C_global, dtype="bool"): - T.evaluate(T.tvm_throw_last_error(dtype="int32")) - for y_outer in T.serial(0, 32): - for x_c_init in T.serial(0, 32): - T.store( - C_global, - T.ramp((x_c_init * 32), 1, 32), - T.broadcast(T.float32(0), 32), - T.broadcast(True, 32), - ) - for k_outer in T.serial(0, 256): - for x_c in T.serial(0, 32): - T.store( - C_global, - T.ramp((x_c * 32), 1, 32), - T.call_llvm_pure_intrin( + ): + if T.isnullptr(C_global.data, dtype="bool"): + T.evaluate(T.tvm_throw_last_error(dtype="int32")) + for y_outer in T.serial(0, 32): + for x_c_init in T.serial(0, 32): + C_global[T.ramp((x_c_init * 32), 1, 32)] = T.broadcast( + T.float32(0), 32 + ) + for k_outer in T.serial(0, 256): + for x_c in T.serial(0, 32): + C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( T.uint32(97), T.uint32(3), T.broadcast( - T.load( - "float32", - A, + A[ ( ((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4) ), - ), + ], 32, ), - T.load( - "float32x32", - packedB, - T.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32), - T.broadcast(True, 32), - ), - T.load( - "float32x32", - C_global, - T.ramp((x_c * 32), 1, 32), - T.broadcast(True, 32), - ), + packedB[ + T.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32) + ], + C_global[T.ramp((x_c * 32), 1, 32)], dtype="float32x32", - ), - T.broadcast(True, 32), - ) - T.store( - C_global, - T.ramp((x_c * 32), 1, 32), - T.call_llvm_pure_intrin( + ) + C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( T.uint32(97), T.uint32(3), T.broadcast( - T.load( - "float32", - A, + A[ ( ( ((x_outer * 32768) + (x_c * 1024)) @@ -483,37 +394,22 @@ def mmult( ) + 1 ), - ), + ], 32, ), - T.load( - "float32x32", - packedB, + packedB[ T.ramp( (((y_outer * 32768) + (k_outer * 128)) + 32), 1, 32 - ), - T.broadcast(True, 32), - ), - T.load( - "float32x32", - C_global, - T.ramp((x_c * 32), 1, 32), - T.broadcast(True, 32), - ), + ) + ], + C_global[T.ramp((x_c * 32), 1, 32)], dtype="float32x32", - ), - T.broadcast(True, 32), - ) - T.store( - C_global, - T.ramp((x_c * 32), 1, 32), - T.call_llvm_pure_intrin( + ) + C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( T.uint32(97), T.uint32(3), T.broadcast( - T.load( - "float32", - A, + A[ ( ( ((x_outer * 32768) + (x_c * 1024)) @@ -521,37 +417,22 @@ def mmult( ) + 2 ), - ), + ], 32, ), - T.load( - "float32x32", - packedB, + packedB[ T.ramp( (((y_outer * 32768) + (k_outer * 128)) + 64), 1, 32 - ), - T.broadcast(True, 32), - ), - T.load( - "float32x32", - C_global, - T.ramp((x_c * 32), 1, 32), - T.broadcast(True, 32), - ), + ) + ], + C_global[T.ramp((x_c * 32), 1, 32)], 
dtype="float32x32", - ), - T.broadcast(True, 32), - ) - T.store( - C_global, - T.ramp((x_c * 32), 1, 32), - T.call_llvm_pure_intrin( + ) + C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( T.uint32(97), T.uint32(3), T.broadcast( - T.load( - "float32", - A, + A[ ( ( ((x_outer * 32768) + (x_c * 1024)) @@ -559,379 +440,263 @@ def mmult( ) + 3 ), - ), + ], 32, ), - T.load( - "float32x32", - packedB, + packedB[ T.ramp( (((y_outer * 32768) + (k_outer * 128)) + 96), 1, 32 - ), - T.broadcast(True, 32), - ), - T.load( - "float32x32", - C_global, - T.ramp((x_c * 32), 1, 32), - T.broadcast(True, 32), - ), + ) + ], + C_global[T.ramp((x_c * 32), 1, 32)], dtype="float32x32", - ), - T.broadcast(True, 32), - ) - for x_inner in T.serial(0, 32): - for y_inner in T.serial(0, 32): - C[ - ( - (((x_outer * 32768) + (x_inner * 1024)) + (y_outer * 32)) - + y_inner ) - ] = T.load("float32", C_global, ((x_inner * 32) + y_inner)) - if T.TVMBackendFreeWorkspace(1, dev_id, C_global, dtype="int32") != 0: - T.evaluate(T.tvm_throw_last_error(dtype="int32")) - if T.TVMBackendFreeWorkspace(1, dev_id, packedB, dtype="int32") != 0: - T.evaluate(T.tvm_throw_last_error(dtype="int32")) + for x_inner in T.serial(0, 32): + for y_inner in T.serial(0, 32): + C[ + ( + ( + ((x_outer * 32768) + (x_inner * 1024)) + + (y_outer * 32) + ) + + y_inner + ) + ] = C_global[((x_inner * 32) + y_inner)] + if T.TVMBackendFreeWorkspace(1, dev_id, C_global.data, dtype="int32") != 0: + T.evaluate(T.tvm_throw_last_error(dtype="int32")) + if T.TVMBackendFreeWorkspace(1, dev_id, packedB.data, dtype="int32") != 0: + T.evaluate(T.tvm_throw_last_error(dtype="int32")) + return Module -def test_opt_gemm_mod_host(): - mod = Module3 - rt_mod = tvm.script.from_source(mod.script(show_meta=True)) - tvm.ir.assert_structural_equal(mod, rt_mod, True) +def opt_conv_tensorcore_normalize(): + @T.prim_func + def func(A: T.handle, W: T.handle, Conv: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + # var definition + bx = T.env_thread("blockIdx.x") + by = T.env_thread("blockIdx.y") + bz = T.env_thread("blockIdx.z") + tx = T.env_thread("threadIdx.x") + ty = T.env_thread("threadIdx.y") + tz = T.env_thread("threadIdx.z") + # buffer definition + Apad_shared = T.buffer_decl( + [16, 16, 16, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 + ) + Apad_shared_wmma_matrix_a = T.buffer_decl( + [16, 16, 16, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 + ) + BA = T.buffer_decl( + [16, 16], dtype="float16", scope="wmma.matrix_a", align=32, offset_factor=256 + ) + BB = T.buffer_decl( + [16, 16], dtype="float16", scope="wmma.matrix_b", align=32, offset_factor=256 + ) + BC = T.buffer_decl([16, 16], scope="wmma.accumulator", align=32, offset_factor=256) + Conv_wmma_accumulator = T.buffer_decl( + [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1 + ) + W_shared = T.buffer_decl( + [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 + ) + W_shared_wmma_matrix_b = T.buffer_decl( + [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 + ) + buffer = T.buffer_decl( + [16, 16], dtype="float16", scope="shared", align=32, offset_factor=256 + ) + buffer_1 = T.buffer_decl( + [16, 16], dtype="float16", scope="wmma.matrix_a", align=32, offset_factor=256 + ) + buffer_2 = T.buffer_decl( + [16, 16], dtype="float16", scope="shared", align=32, offset_factor=256 + ) + buffer_3 = T.buffer_decl( + [16, 
16], dtype="float16", scope="wmma.matrix_b", align=32, offset_factor=256 + ) + buffer_4 = T.buffer_decl([16, 16], scope="wmma.accumulator", align=32, offset_factor=256) + buffer_5 = T.buffer_decl([16, 16], align=32, offset_factor=256) + A_1 = T.match_buffer( + A, [16, 14, 14, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 + ) + W_1 = T.match_buffer( + W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 + ) + Conv_1 = T.match_buffer( + Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1 + ) + # body + T.realize(Conv_1[0:16, 0:14, 0:14, 0:32, 0:16, 0:16], "") + T.launch_thread(bz, 196) + T.launch_thread(bx, 2) + T.launch_thread(by, 4) + T.launch_thread(ty, 4) + T.launch_thread(tz, 2) + T.realize( + Conv_wmma_accumulator[ + ((bx * 8) + (ty * 2)) : (((bx * 8) + (ty * 2)) + 2), + T.floordiv(bz, 14) : (T.floordiv(bz, 14) + 1), + T.floormod(bz, 14) : (T.floormod(bz, 14) + 1), + ((by * 8) + (tz * 4)) : (((by * 8) + (tz * 4)) + 4), + 0:16, + 0:16, + ], + "wmma.accumulator", + ) + for n_c_init in T.serial(0, 2): + for o_c_init in T.serial(0, 4): + T.attr( + [BC, Conv_wmma_accumulator], + "buffer_bind_scope", + T.tvm_tuple( + (n_c_init + ((bx * 8) + (ty * 2))), + 1, + T.floordiv(bz, 14), + 1, + T.floormod(bz, 14), + 1, + (o_c_init + ((by * 8) + (tz * 4))), + 1, + 0, + 16, + 0, + 16, + dtype="handle", + ), + ) + T.evaluate( + T.tvm_fill_fragment( + BC.data, + 16, + 16, + 16, + T.floordiv(BC.elem_offset, 256), + T.float32(0), + dtype="handle", + ) + ) -@T.prim_func -def opt_conv_tensorcore_normalize(A: T.handle, W: T.handle, Conv: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) - # var definition - bx = T.env_thread("blockIdx.x") - by = T.env_thread("blockIdx.y") - bz = T.env_thread("blockIdx.z") - tx = T.env_thread("threadIdx.x") - ty = T.env_thread("threadIdx.y") - tz = T.env_thread("threadIdx.z") - # buffer definition - Apad_shared = T.buffer_decl( - [16, 16, 16, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 - ) - Apad_shared_wmma_matrix_a = T.buffer_decl( - [16, 16, 16, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 - ) - BA = T.buffer_decl( - [16, 16], dtype="float16", scope="wmma.matrix_a", align=32, offset_factor=256 - ) - BB = T.buffer_decl( - [16, 16], dtype="float16", scope="wmma.matrix_b", align=32, offset_factor=256 - ) - BC = T.buffer_decl([16, 16], scope="wmma.accumulator", align=32, offset_factor=256) - Conv_wmma_accumulator = T.buffer_decl( - [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1 - ) - W_shared = T.buffer_decl( - [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 - ) - W_shared_wmma_matrix_b = T.buffer_decl( - [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 - ) - buffer = T.buffer_decl([16, 16], dtype="float16", scope="shared", align=32, offset_factor=256) - buffer_1 = T.buffer_decl( - [16, 16], dtype="float16", scope="wmma.matrix_a", align=32, offset_factor=256 - ) - buffer_2 = T.buffer_decl([16, 16], dtype="float16", scope="shared", align=32, offset_factor=256) - buffer_3 = T.buffer_decl( - [16, 16], dtype="float16", scope="wmma.matrix_b", align=32, offset_factor=256 - ) - buffer_4 = T.buffer_decl([16, 16], scope="wmma.accumulator", align=32, offset_factor=256) - buffer_5 = T.buffer_decl([16, 16], align=32, offset_factor=256) - A_1 = T.match_buffer( - A, [16, 14, 14, 16, 16, 16], 
dtype="float16", elem_offset=0, align=128, offset_factor=1 - ) - W_1 = T.match_buffer( - W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 - ) - Conv_1 = T.match_buffer( - Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1 - ) - # body - T.realize(Conv_1[0:16, 0:14, 0:14, 0:32, 0:16, 0:16], "") - T.launch_thread(bz, 196) - T.launch_thread(bx, 2) - T.launch_thread(by, 4) - T.launch_thread(ty, 4) - T.launch_thread(tz, 2) - T.realize( - Conv_wmma_accumulator[ - ((bx * 8) + (ty * 2)) : (((bx * 8) + (ty * 2)) + 2), - T.floordiv(bz, 14) : (T.floordiv(bz, 14) + 1), - T.floormod(bz, 14) : (T.floormod(bz, 14) + 1), - ((by * 8) + (tz * 4)) : (((by * 8) + (tz * 4)) + 4), - 0:16, - 0:16, - ], - "wmma.accumulator", - ) - for n_c_init in T.serial(0, 2): - for o_c_init in T.serial(0, 4): - T.attr( - [BC, Conv_wmma_accumulator], - "buffer_bind_scope", - T.tvm_tuple( - (n_c_init + ((bx * 8) + (ty * 2))), - 1, - T.floordiv(bz, 14), - 1, - T.floormod(bz, 14), - 1, - (o_c_init + ((by * 8) + (tz * 4))), - 1, - 0, - 16, - 0, - 16, - dtype="handle", - ), - ) - T.evaluate( - T.tvm_fill_fragment( - BC.data, - 16, - 16, - 16, - T.floordiv(BC.elem_offset, 256), - T.float32(0), - dtype="handle", + for ic_outer in T.serial(0, 8): + for kh in T.serial(0, 3): + T.realize( + Apad_shared[ + (bx * 8) : ((bx * 8) + 8), + (T.floordiv(bz, 14) + kh) : ((T.floordiv(bz, 14) + kh) + 1), + T.floormod(bz, 14) : (T.floormod(bz, 14) + 3), + (ic_outer * 2) : ((ic_outer * 2) + 2), + 0:16, + 0:16, + ], + "shared", ) - ) - for ic_outer in T.serial(0, 8): - for kh in T.serial(0, 3): - T.realize( - Apad_shared[ - (bx * 8) : ((bx * 8) + 8), - (T.floordiv(bz, 14) + kh) : ((T.floordiv(bz, 14) + kh) + 1), - T.floormod(bz, 14) : (T.floormod(bz, 14) + 3), - (ic_outer * 2) : ((ic_outer * 2) + 2), - 0:16, - 0:16, - ], - "shared", - ) - for ax2 in T.serial(0, 3): - for ax3 in T.serial(0, 2): - for ax4_ax5_fused_outer in T.serial(0, 8): - T.launch_thread(tx, 32) - Apad_shared[ - ((tz + (ty * 2)) + (bx * 8)), - (T.floordiv(bz, 14) + kh), - (ax2 + T.floormod(bz, 14)), - (ax3 + (ic_outer * 2)), - T.floordiv((tx + (ax4_ax5_fused_outer * 32)), 16), - T.floormod((tx + (ax4_ax5_fused_outer * 32)), 16), - ] = T.if_then_else( - ( - ( - ( - ((T.floordiv(bz, 14) + kh) >= 1) - and (((T.floordiv(bz, 14) + kh) - 1) < 14) - ) - and ((ax2 + T.floormod(bz, 14)) >= 1) - ) - and (((ax2 + T.floormod(bz, 14)) - 1) < 14) - ), - A_1[ + for ax2 in T.serial(0, 3): + for ax3 in T.serial(0, 2): + for ax4_ax5_fused_outer in T.serial(0, 8): + T.launch_thread(tx, 32) + Apad_shared[ ((tz + (ty * 2)) + (bx * 8)), - ((T.floordiv(bz, 14) + kh) - 1), - ((ax2 + T.floormod(bz, 14)) - 1), + (T.floordiv(bz, 14) + kh), + (ax2 + T.floormod(bz, 14)), (ax3 + (ic_outer * 2)), T.floordiv((tx + (ax4_ax5_fused_outer * 32)), 16), T.floormod((tx + (ax4_ax5_fused_outer * 32)), 16), + ] = T.if_then_else( + ( + ( + ( + ((T.floordiv(bz, 14) + kh) >= 1) + and (((T.floordiv(bz, 14) + kh) - 1) < 14) + ) + and ((ax2 + T.floormod(bz, 14)) >= 1) + ) + and (((ax2 + T.floormod(bz, 14)) - 1) < 14) + ), + A_1[ + ((tz + (ty * 2)) + (bx * 8)), + ((T.floordiv(bz, 14) + kh) - 1), + ((ax2 + T.floormod(bz, 14)) - 1), + (ax3 + (ic_outer * 2)), + T.floordiv((tx + (ax4_ax5_fused_outer * 32)), 16), + T.floormod((tx + (ax4_ax5_fused_outer * 32)), 16), + ], + T.float16(0), + dtype="float16", + ) + T.realize( + W_shared[ + kh : (kh + 1), + 0:3, + (ic_outer * 2) : ((ic_outer * 2) + 2), + (by * 8) : ((by * 8) + 8), + 0:16, + 0:16, + ], + "shared", + ) + for ax1 in 
T.serial(0, 3): + for ax2_1 in T.serial(0, 2): + T.launch_thread(tx, 32) + for ax4_ax5_fused_inner in T.vectorized(0, 8): + W_shared[ + kh, + ax1, + (ax2_1 + (ic_outer * 2)), + ((tz + (ty * 2)) + (by * 8)), + T.floordiv((ax4_ax5_fused_inner + (tx * 8)), 16), + T.floormod((ax4_ax5_fused_inner + (tx * 8)), 16), + ] = W_1[ + kh, + ax1, + (ax2_1 + (ic_outer * 2)), + ((tz + (ty * 2)) + (by * 8)), + T.floordiv((ax4_ax5_fused_inner + (tx * 8)), 16), + T.floormod((ax4_ax5_fused_inner + (tx * 8)), 16), + ] + for ic_inner in T.serial(0, 2): + for kw in T.serial(0, 3): + T.realize( + Apad_shared_wmma_matrix_a[ + ((bx * 8) + (ty * 2)) : (((bx * 8) + (ty * 2)) + 2), + (T.floordiv(bz, 14) + kh) : ((T.floordiv(bz, 14) + kh) + 1), + (kw + T.floormod(bz, 14)) : ((kw + T.floormod(bz, 14)) + 1), + ((ic_outer * 2) + ic_inner) : (((ic_outer * 2) + ic_inner) + 1), + 0:16, + 0:16, ], - T.float16(0), - dtype="float16", + "wmma.matrix_a", ) - T.realize( - W_shared[ - kh : (kh + 1), - 0:3, - (ic_outer * 2) : ((ic_outer * 2) + 2), - (by * 8) : ((by * 8) + 8), - 0:16, - 0:16, - ], - "shared", - ) - for ax1 in T.serial(0, 3): - for ax2_1 in T.serial(0, 2): - T.launch_thread(tx, 32) - for ax4_ax5_fused_inner in T.vectorized(0, 8): - W_shared[ - kh, - ax1, - (ax2_1 + (ic_outer * 2)), - ((tz + (ty * 2)) + (by * 8)), - T.floordiv((ax4_ax5_fused_inner + (tx * 8)), 16), - T.floormod((ax4_ax5_fused_inner + (tx * 8)), 16), - ] = W_1[ - kh, - ax1, - (ax2_1 + (ic_outer * 2)), - ((tz + (ty * 2)) + (by * 8)), - T.floordiv((ax4_ax5_fused_inner + (tx * 8)), 16), - T.floormod((ax4_ax5_fused_inner + (tx * 8)), 16), - ] - for ic_inner in T.serial(0, 2): - for kw in T.serial(0, 3): - T.realize( - Apad_shared_wmma_matrix_a[ - ((bx * 8) + (ty * 2)) : (((bx * 8) + (ty * 2)) + 2), - (T.floordiv(bz, 14) + kh) : ((T.floordiv(bz, 14) + kh) + 1), - (kw + T.floormod(bz, 14)) : ((kw + T.floormod(bz, 14)) + 1), - ((ic_outer * 2) + ic_inner) : (((ic_outer * 2) + ic_inner) + 1), - 0:16, - 0:16, - ], - "wmma.matrix_a", - ) - for ax0 in T.serial(0, 2): - T.attr( - [buffer, Apad_shared], - "buffer_bind_scope", - T.tvm_tuple( - (ax0 + ((bx * 8) + (ty * 2))), - 1, - (T.floordiv(bz, 14) + kh), - 1, - (kw + T.floormod(bz, 14)), - 1, - ((ic_outer * 2) + ic_inner), - 1, - 0, - 16, - 0, - 16, - dtype="handle", - ), - ) - T.attr( - [buffer_1, Apad_shared_wmma_matrix_a], - "buffer_bind_scope", - T.tvm_tuple( - (ax0 + ((bx * 8) + (ty * 2))), - 1, - (T.floordiv(bz, 14) + kh), - 1, - (kw + T.floormod(bz, 14)), - 1, - ((ic_outer * 2) + ic_inner), - 1, - 0, - 16, - 0, - 16, - dtype="handle", - ), - ) - T.evaluate( - T.tvm_load_matrix_sync( - buffer_1.data, - 16, - 16, - 16, - T.floordiv(buffer_1.elem_offset, 256), - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - buffer.data, - buffer.elem_offset, - 256, + for ax0 in T.serial(0, 2): + T.attr( + [buffer, Apad_shared], + "buffer_bind_scope", + T.tvm_tuple( + (ax0 + ((bx * 8) + (ty * 2))), 1, - dtype="handle", - ), - 16, - "row_major", - dtype="handle", - ) - ) - T.realize( - W_shared_wmma_matrix_b[ - kh : (kh + 1), - kw : (kw + 1), - ((ic_outer * 2) + ic_inner) : (((ic_outer * 2) + ic_inner) + 1), - ((by * 8) + (tz * 4)) : (((by * 8) + (tz * 4)) + 4), - 0:16, - 0:16, - ], - "wmma.matrix_b", - ) - for ax3_1 in T.serial(0, 4): - T.attr( - [buffer_2, W_shared], - "buffer_bind_scope", - T.tvm_tuple( - kh, - 1, - kw, - 1, - ((ic_outer * 2) + ic_inner), - 1, - (ax3_1 + ((by * 8) + (tz * 4))), - 1, - 0, - 16, - 0, - 16, - dtype="handle", - ), - ) - T.attr( - [buffer_3, W_shared_wmma_matrix_b], - 
"buffer_bind_scope", - T.tvm_tuple( - kh, - 1, - kw, - 1, - ((ic_outer * 2) + ic_inner), - 1, - (ax3_1 + ((by * 8) + (tz * 4))), - 1, - 0, - 16, - 0, - 16, - dtype="handle", - ), - ) - T.evaluate( - T.tvm_load_matrix_sync( - buffer_3.data, - 16, - 16, - 16, - T.floordiv(buffer_3.elem_offset, 256), - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - buffer_2.data, - buffer_2.elem_offset, - 256, + (T.floordiv(bz, 14) + kh), 1, + (kw + T.floormod(bz, 14)), + 1, + ((ic_outer * 2) + ic_inner), + 1, + 0, + 16, + 0, + 16, dtype="handle", ), - 16, - "row_major", - dtype="handle", ) - ) - for n_c in T.serial(0, 2): - for o_c in T.serial(0, 4): T.attr( - [BA, Apad_shared_wmma_matrix_a], + [buffer_1, Apad_shared_wmma_matrix_a], "buffer_bind_scope", T.tvm_tuple( - (n_c + ((bx * 8) + (ty * 2))), + (ax0 + ((bx * 8) + (ty * 2))), 1, (T.floordiv(bz, 14) + kh), 1, - (T.floormod(bz, 14) + kw), + (kw + T.floormod(bz, 14)), 1, ((ic_outer * 2) + ic_inner), 1, @@ -942,8 +707,40 @@ def opt_conv_tensorcore_normalize(A: T.handle, W: T.handle, Conv: T.handle) -> N dtype="handle", ), ) + T.evaluate( + T.tvm_load_matrix_sync( + buffer_1.data, + 16, + 16, + 16, + T.floordiv(buffer_1.elem_offset, 256), + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + buffer.data, + buffer.elem_offset, + 256, + 1, + dtype="handle", + ), + 16, + "row_major", + dtype="handle", + ) + ) + T.realize( + W_shared_wmma_matrix_b[ + kh : (kh + 1), + kw : (kw + 1), + ((ic_outer * 2) + ic_inner) : (((ic_outer * 2) + ic_inner) + 1), + ((by * 8) + (tz * 4)) : (((by * 8) + (tz * 4)) + 4), + 0:16, + 0:16, + ], + "wmma.matrix_b", + ) + for ax3_1 in T.serial(0, 4): T.attr( - [BB, W_shared_wmma_matrix_b], + [buffer_2, W_shared], "buffer_bind_scope", T.tvm_tuple( kh, @@ -952,7 +749,7 @@ def opt_conv_tensorcore_normalize(A: T.handle, W: T.handle, Conv: T.handle) -> N 1, ((ic_outer * 2) + ic_inner), 1, - (o_c + ((by * 8) + (tz * 4))), + (ax3_1 + ((by * 8) + (tz * 4))), 1, 0, 16, @@ -962,16 +759,16 @@ def opt_conv_tensorcore_normalize(A: T.handle, W: T.handle, Conv: T.handle) -> N ), ) T.attr( - [BC, Conv_wmma_accumulator], + [buffer_3, W_shared_wmma_matrix_b], "buffer_bind_scope", T.tvm_tuple( - (n_c + ((bx * 8) + (ty * 2))), + kh, 1, - T.floordiv(bz, 14), + kw, 1, - T.floormod(bz, 14), + ((ic_outer * 2) + ic_inner), 1, - (o_c + ((by * 8) + (tz * 4))), + (ax3_1 + ((by * 8) + (tz * 4))), 1, 0, 16, @@ -981,748 +778,851 @@ def opt_conv_tensorcore_normalize(A: T.handle, W: T.handle, Conv: T.handle) -> N ), ) T.evaluate( - T.tvm_mma_sync( - BC.data, - T.floordiv(BC.elem_offset, 256), - BA.data, - T.floordiv(BA.elem_offset, 256), - BB.data, - T.floordiv(BB.elem_offset, 256), - BC.data, - T.floordiv(BC.elem_offset, 256), + T.tvm_load_matrix_sync( + buffer_3.data, + 16, + 16, + 16, + T.floordiv(buffer_3.elem_offset, 256), + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + buffer_2.data, + buffer_2.elem_offset, + 256, + 1, + dtype="handle", + ), + 16, + "row_major", dtype="handle", ) ) - for n_inner in T.serial(0, 2): - for o_inner in T.serial(0, 4): - T.attr( - [buffer_4, Conv_wmma_accumulator], - "buffer_bind_scope", - T.tvm_tuple( - ((((bx * 4) + ty) * 2) + n_inner), - 1, - T.floordiv(bz, 14), - 1, - T.floormod(bz, 14), - 1, - ((((by * 2) + tz) * 4) + o_inner), - 1, - 0, - 16, - 0, - 16, - dtype="handle", - ), - ) - T.attr( - [buffer_5, Conv_1], - "buffer_bind_scope", - T.tvm_tuple( - ((((bx * 4) + ty) * 2) + n_inner), - 1, - T.floordiv(bz, 14), - 1, - T.floormod(bz, 14), - 1, - ((((by * 2) + tz) * 4) + o_inner), - 1, - 0, - 16, - 0, - 
16, - dtype="handle", - ), - ) - T.evaluate( - T.tvm_store_matrix_sync( - buffer_4.data, - 16, - 16, - 16, - T.floordiv(buffer_4.elem_offset, 256), - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), - buffer_5.data, - buffer_5.elem_offset, - 256, - 2, + for n_c in T.serial(0, 2): + for o_c in T.serial(0, 4): + T.attr( + [BA, Apad_shared_wmma_matrix_a], + "buffer_bind_scope", + T.tvm_tuple( + (n_c + ((bx * 8) + (ty * 2))), + 1, + (T.floordiv(bz, 14) + kh), + 1, + (T.floormod(bz, 14) + kw), + 1, + ((ic_outer * 2) + ic_inner), + 1, + 0, + 16, + 0, + 16, + dtype="handle", + ), + ) + T.attr( + [BB, W_shared_wmma_matrix_b], + "buffer_bind_scope", + T.tvm_tuple( + kh, + 1, + kw, + 1, + ((ic_outer * 2) + ic_inner), + 1, + (o_c + ((by * 8) + (tz * 4))), + 1, + 0, + 16, + 0, + 16, + dtype="handle", + ), + ) + T.attr( + [BC, Conv_wmma_accumulator], + "buffer_bind_scope", + T.tvm_tuple( + (n_c + ((bx * 8) + (ty * 2))), + 1, + T.floordiv(bz, 14), + 1, + T.floormod(bz, 14), + 1, + (o_c + ((by * 8) + (tz * 4))), + 1, + 0, + 16, + 0, + 16, + dtype="handle", + ), + ) + T.evaluate( + T.tvm_mma_sync( + BC.data, + T.floordiv(BC.elem_offset, 256), + BA.data, + T.floordiv(BA.elem_offset, 256), + BB.data, + T.floordiv(BB.elem_offset, 256), + BC.data, + T.floordiv(BC.elem_offset, 256), + dtype="handle", + ) + ) + for n_inner in T.serial(0, 2): + for o_inner in T.serial(0, 4): + T.attr( + [buffer_4, Conv_wmma_accumulator], + "buffer_bind_scope", + T.tvm_tuple( + ((((bx * 4) + ty) * 2) + n_inner), + 1, + T.floordiv(bz, 14), + 1, + T.floormod(bz, 14), + 1, + ((((by * 2) + tz) * 4) + o_inner), + 1, + 0, + 16, + 0, + 16, dtype="handle", ), - 16, - "row_major", - dtype="handle", ) - ) - + T.attr( + [buffer_5, Conv_1], + "buffer_bind_scope", + T.tvm_tuple( + ((((bx * 4) + ty) * 2) + n_inner), + 1, + T.floordiv(bz, 14), + 1, + T.floormod(bz, 14), + 1, + ((((by * 2) + tz) * 4) + o_inner), + 1, + 0, + 16, + 0, + 16, + dtype="handle", + ), + ) + T.evaluate( + T.tvm_store_matrix_sync( + buffer_4.data, + 16, + 16, + 16, + T.floordiv(buffer_4.elem_offset, 256), + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), + buffer_5.data, + buffer_5.elem_offset, + 256, + 2, + dtype="handle", + ), + 16, + "row_major", + dtype="handle", + ) + ) -def test_opt_conv_tensorcore_normalize(): - mod = opt_conv_tensorcore_normalize - rt_mod = tvm.script.from_source(mod.script(show_meta=True)) - tvm.ir.assert_structural_equal(mod, rt_mod, True) + return func -@T.prim_func -def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) - # body - A_1 = T.match_buffer( - A, [16, 14, 14, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 - ) - W_1 = T.match_buffer( - W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 - ) - Conv_1 = T.match_buffer( - Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1 - ) - bx = T.env_thread("blockIdx.x") - by = T.env_thread("blockIdx.y") - bz = T.env_thread("blockIdx.z") - tx = T.env_thread("threadIdx.x") - ty = T.env_thread("threadIdx.y") - tz = T.env_thread("threadIdx.z") - T.launch_thread(bz, 196) - Conv_wmma_accumulator = T.allocate([2048], "float32", "wmma.accumulator") - Apad_shared = T.allocate([12288], "float16", "shared") - W_shared = T.allocate([12288], "float16", "shared") - Apad_shared_wmma_matrix_a = T.allocate([512], "float16", "wmma.matrix_a") - W_shared_wmma_matrix_b = T.allocate([1024], "float16", 
"wmma.matrix_b") - T.launch_thread(bx, 2) - T.launch_thread(by, 4) - T.launch_thread(ty, 4) - T.launch_thread(tz, 2) - T.evaluate( - T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 0, T.float32(0), dtype="handle") - ) - T.evaluate( - T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 1, T.float32(0), dtype="handle") - ) - T.evaluate( - T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 2, T.float32(0), dtype="handle") - ) - T.evaluate( - T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 3, T.float32(0), dtype="handle") - ) - T.evaluate( - T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 4, T.float32(0), dtype="handle") - ) - T.evaluate( - T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 5, T.float32(0), dtype="handle") - ) - T.evaluate( - T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 6, T.float32(0), dtype="handle") - ) - T.evaluate( - T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 7, T.float32(0), dtype="handle") - ) - for ic_outer in T.serial(0, 8): - for kh in T.serial(0, 3): - for ax2 in T.serial(0, 3): - with T.launch_thread(tx, 32): - Apad_shared[ - ((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) - ] = T.if_then_else( - ( +def opt_conv_tensorcore_lower(): + @T.prim_func + def func( + A: T.Buffer[(16, 14, 14, 16, 16, 16), "float16"], + W: T.Buffer[(3, 3, 16, 32, 16, 16), "float16"], + Conv: T.Buffer[(16, 14, 14, 32, 16, 16), "float32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + # body + A_1 = T.buffer_decl([12845056], dtype="float16", data=A.data) + W_1 = T.buffer_decl([1179648], dtype="float16", data=W.data) + Conv_1 = T.buffer_decl([25690112], data=Conv.data) + bx = T.env_thread("blockIdx.x") + by = T.env_thread("blockIdx.y") + bz = T.env_thread("blockIdx.z") + tx = T.env_thread("threadIdx.x") + ty = T.env_thread("threadIdx.y") + tz = T.env_thread("threadIdx.z") + T.launch_thread(bz, 196) + Conv_wmma_accumulator = T.allocate([2048], "float32", "wmma.accumulator") + Apad_shared = T.allocate([12288], "float16", "shared") + W_shared = T.allocate([12288], "float16", "shared") + Apad_shared_wmma_matrix_a = T.allocate([512], "float16", "wmma.matrix_a") + W_shared_wmma_matrix_b = T.allocate([1024], "float16", "wmma.matrix_b") + T.launch_thread(bx, 2) + T.launch_thread(by, 4) + T.launch_thread(ty, 4) + T.launch_thread(tz, 2) + T.evaluate( + T.tvm_fill_fragment( + Conv_wmma_accumulator.data, 16, 16, 16, 0, T.float32(0), dtype="handle" + ) + ) + T.evaluate( + T.tvm_fill_fragment( + Conv_wmma_accumulator.data, 16, 16, 16, 1, T.float32(0), dtype="handle" + ) + ) + T.evaluate( + T.tvm_fill_fragment( + Conv_wmma_accumulator.data, 16, 16, 16, 2, T.float32(0), dtype="handle" + ) + ) + T.evaluate( + T.tvm_fill_fragment( + Conv_wmma_accumulator.data, 16, 16, 16, 3, T.float32(0), dtype="handle" + ) + ) + T.evaluate( + T.tvm_fill_fragment( + Conv_wmma_accumulator.data, 16, 16, 16, 4, T.float32(0), dtype="handle" + ) + ) + T.evaluate( + T.tvm_fill_fragment( + Conv_wmma_accumulator.data, 16, 16, 16, 5, T.float32(0), dtype="handle" + ) + ) + T.evaluate( + T.tvm_fill_fragment( + Conv_wmma_accumulator.data, 16, 16, 16, 6, T.float32(0), dtype="handle" + ) + ) + T.evaluate( + T.tvm_fill_fragment( + Conv_wmma_accumulator.data, 16, 16, 16, 7, T.float32(0), dtype="handle" + ) + ) + for ic_outer in T.serial(0, 8): + for kh in T.serial(0, 3): + for ax2 in T.serial(0, 3): + with T.launch_thread(tx, 32): + Apad_shared[ + ((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + ] = T.if_then_else( ( ( 
- (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - T.load( - "float16", - A_1.data, - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61440 - ), - ), - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 32) - ] = T.if_then_else( - ( + - 61440 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 32) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - T.load( - "float16", - A_1.data, - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61408 - ), - ), - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 64) - ] = T.if_then_else( - ( + - 61408 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 64) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - T.load( - "float16", - A_1.data, - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx + - 61376 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 96) + ] = T.if_then_else( + ( + ( + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - - 61376 + and ((ax2 + T.floormod(bz, 14)) < 15) ), - ), - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 96) - ] = T.if_then_else( - ( - ( - ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) - ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - T.load( - "float16", - A_1.data, - 
( + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61344 - ), - ), - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 128) - ] = T.if_then_else( - ( + - 61344 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 128) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - T.load( - "float16", - A_1.data, - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61312 - ), - ), - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 160) - ] = T.if_then_else( - ( + - 61312 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 160) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - T.load( - "float16", - A_1.data, - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61280 - ), - ), - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 192) - ] = T.if_then_else( - ( + - 61280 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 192) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - T.load( - "float16", - A_1.data, - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61248 - ), - ), - T.float16(0), - dtype="float16", - ) - 
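+                    # Editor's note: each of these unrolled shared-memory stores
+                    # follows the same rewrite; only the +32-element offset
+                    # changes. Schematically, with in_bounds, i, and j standing
+                    # for the predicate and flattened indices spelled out above
+                    # (they are not names introduced by the patch):
+                    #
+                    #   before: Apad_shared[i] = T.if_then_else(in_bounds,
+                    #               T.load("float16", A_1.data, j),
+                    #               T.float16(0), dtype="float16")
+                    #   after:  Apad_shared[i] = T.if_then_else(in_bounds,
+                    #               A_1[j], T.float16(0), dtype="float16")
+                    #
+                    # A_1 itself is now a flat T.buffer_decl alias over A.data, so
+                    # the subscript uses the same linearized offset as the old
+                    # T.load call.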
with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 224) - ] = T.if_then_else( - ( + - 61248 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 224) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - T.load( - "float16", - A_1.data, - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61216 - ), - ), - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 256) - ] = T.if_then_else( - ( + - 61216 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 256) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - T.load( - "float16", - A_1.data, - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61184 - ), - ), - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 288) - ] = T.if_then_else( - ( + - 61184 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 288) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - T.load( - "float16", - A_1.data, - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61152 - ), - ), - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 320) - ] = T.if_then_else( - ( + - 61152 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 320) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and 
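+                    # Editor's note (applies to the W_shared copies further down):
+                    # vectorized accesses drop their explicit all-true masks when
+                    # rewritten as ramp subscripts. Schematically, with i and j
+                    # abbreviating the flattened offsets spelled out in the hunks:
+                    #
+                    #   before: T.store(W_shared, T.ramp(i, 1, 8),
+                    #               T.load("float16x8", W_1.data, T.ramp(j, 1, 8),
+                    #                      T.broadcast(True, 8)),
+                    #               T.broadcast(True, 8))
+                    #   after:  W_shared[T.ramp(i, 1, 8)] = W_1[T.ramp(j, 1, 8)]
+                    #
+                    # The buffer subscript form makes the masks implicit; dtype
+                    # and lane count come from the buffer and the ramp.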
((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - T.load( - "float16", - A_1.data, - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61120 - ), - ), - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 352) - ] = T.if_then_else( - ( + - 61120 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 352) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - T.load( - "float16", - A_1.data, - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61088 - ), - ), - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 384) - ] = T.if_then_else( - ( + - 61088 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 384) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - T.load( - "float16", - A_1.data, - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61056 - ), - ), - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 416) - ] = T.if_then_else( - ( + - 61056 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 416) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - T.load( - "float16", - A_1.data, - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + 
(ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx + - 61024 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 448) + ] = T.if_then_else( + ( + ( + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - - 61024 + and ((ax2 + T.floormod(bz, 14)) < 15) ), - ), - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): + A_1[ + ( + ( + ( + ( + ( + ( + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) + ) + + (bz * 4096) + ) + + (ax2 * 4096) + ) + + (ic_outer * 512) + ) + + tx + ) + - 60992 + ), + ], + T.float16(0), + dtype="float16", + ) + T.launch_thread(tx, 32) Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 448) + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 480) ] = T.if_then_else( ( ( @@ -1734,9 +1634,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - T.load( - "float16", - A_1.data, + A_1[ ( ( ( @@ -1757,56 +1655,14 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) + tx ) - - 60992 + - 60960 ), - ), + ], T.float16(0), dtype="float16", ) - T.launch_thread(tx, 32) - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 480) - ] = T.if_then_else( - ( - ( - ((1 <= (T.floordiv(bz, 14) + kh)) and ((T.floordiv(bz, 14) + kh) < 15)) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - T.load( - "float16", - A_1.data, - ( - ( - ( - ( - ( - ( - (((bx * 6422528) + (ty * 1605632)) + (tz * 802816)) - + (kh * 57344) - ) - + (bz * 4096) - ) - + (ax2 * 4096) - ) - + (ic_outer * 512) - ) - + tx - ) - - 60960 - ), - ), - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - T.store( - W_shared, - T.ramp((((ty * 512) + (tz * 256)) + (tx * 8)), 1, 8), - T.load( - "float16x8", - W_1.data, + with T.launch_thread(tx, 32): + W_shared[T.ramp((((ty * 512) + (tz * 256)) + (tx * 8)), 1, 8)] = W_1[ T.ramp( ( ( @@ -1820,18 +1676,10 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ), 1, 8, - ), - T.broadcast(True, 8), - ), - T.broadcast(True, 8), - ) - with T.launch_thread(tx, 32): - T.store( - W_shared, - T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 2048), 1, 8), - T.load( - "float16x8", - W_1.data, + ) + ] + with T.launch_thread(tx, 32): + W_shared[T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 2048), 1, 8)] = W_1[ T.ramp( ( ( @@ -1848,18 +1696,10 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ), 1, 8, - ), - T.broadcast(True, 8), - ), - T.broadcast(True, 8), - ) - with T.launch_thread(tx, 32): - T.store( - W_shared, - T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 4096), 1, 8), - T.load( - "float16x8", - W_1.data, + ) + ] + with T.launch_thread(tx, 32): + W_shared[T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 4096), 1, 8)] = W_1[ T.ramp( ( ( @@ -1876,18 +1716,10 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ), 1, 8, - ), - T.broadcast(True, 8), - ), - T.broadcast(True, 8), - ) - with T.launch_thread(tx, 32): - T.store( - W_shared, - T.ramp(((((ty * 512) + (tz * 256)) + 
(tx * 8)) + 6144), 1, 8), - T.load( - "float16x8", - W_1.data, + ) + ] + with T.launch_thread(tx, 32): + W_shared[T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 6144), 1, 8)] = W_1[ T.ramp( ( ( @@ -1904,18 +1736,10 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ), 1, 8, - ), - T.broadcast(True, 8), - ), - T.broadcast(True, 8), - ) - with T.launch_thread(tx, 32): - T.store( - W_shared, - T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 8192), 1, 8), - T.load( - "float16x8", - W_1.data, + ) + ] + with T.launch_thread(tx, 32): + W_shared[T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 8192), 1, 8)] = W_1[ T.ramp( ( ( @@ -1932,18 +1756,10 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ), 1, 8, - ), - T.broadcast(True, 8), - ), - T.broadcast(True, 8), - ) - with T.launch_thread(tx, 32): - T.store( - W_shared, - T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 10240), 1, 8), - T.load( - "float16x8", - W_1.data, + ) + ] + with T.launch_thread(tx, 32): + W_shared[T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 10240), 1, 8)] = W_1[ T.ramp( ( ( @@ -1960,791 +1776,801 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ), 1, 8, - ), - T.broadcast(True, 8), - ), - T.broadcast(True, 8), - ) - for ic_inner in T.serial(0, 2): - for kw in T.serial(0, 3): - T.evaluate( - T.tvm_load_matrix_sync( - Apad_shared_wmma_matrix_a, - 16, - 16, - 16, - 0, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - Apad_shared, - (((ty * 3072) + (kw * 512)) + (ic_inner * 256)), - 256, - 1, + ) + ] + for ic_inner in T.serial(0, 2): + for kw in T.serial(0, 3): + T.evaluate( + T.tvm_load_matrix_sync( + Apad_shared_wmma_matrix_a.data, + 16, + 16, + 16, + 0, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + Apad_shared.data, + (((ty * 3072) + (kw * 512)) + (ic_inner * 256)), + 256, + 1, + dtype="handle", + ), + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_load_matrix_sync( - Apad_shared_wmma_matrix_a, - 16, - 16, - 16, - 1, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - Apad_shared, - ((((ty * 3072) + (kw * 512)) + (ic_inner * 256)) + 1536), - 256, + T.evaluate( + T.tvm_load_matrix_sync( + Apad_shared_wmma_matrix_a.data, + 16, + 16, + 16, 1, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + Apad_shared.data, + ((((ty * 3072) + (kw * 512)) + (ic_inner * 256)) + 1536), + 256, + 1, + dtype="handle", + ), + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_load_matrix_sync( - W_shared_wmma_matrix_b, - 16, - 16, - 16, - 0, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - W_shared, - (((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)), - 256, - 1, + T.evaluate( + T.tvm_load_matrix_sync( + W_shared_wmma_matrix_b.data, + 16, + 16, + 16, + 0, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + W_shared.data, + (((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)), + 256, + 1, + dtype="handle", + ), + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_load_matrix_sync( - W_shared_wmma_matrix_b, - 16, - 16, - 16, - 1, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - W_shared, - ((((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)) + 256), - 256, + T.evaluate( + T.tvm_load_matrix_sync( + W_shared_wmma_matrix_b.data, + 16, + 16, + 16, 1, + T.tvm_access_ptr( + 
T.type_annotation(dtype="float16"), + W_shared.data, + ((((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)) + 256), + 256, + 1, + dtype="handle", + ), + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_load_matrix_sync( - W_shared_wmma_matrix_b, - 16, - 16, - 16, - 2, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - W_shared, - ((((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)) + 512), - 256, - 1, + T.evaluate( + T.tvm_load_matrix_sync( + W_shared_wmma_matrix_b.data, + 16, + 16, + 16, + 2, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + W_shared.data, + ((((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)) + 512), + 256, + 1, + dtype="handle", + ), + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_load_matrix_sync( - W_shared_wmma_matrix_b, - 16, - 16, - 16, - 3, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - W_shared, - ((((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)) + 768), - 256, - 1, + T.evaluate( + T.tvm_load_matrix_sync( + W_shared_wmma_matrix_b.data, + 16, + 16, + 16, + 3, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + W_shared.data, + ((((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)) + 768), + 256, + 1, + dtype="handle", + ), + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_mma_sync( - Conv_wmma_accumulator, - 0, - Apad_shared_wmma_matrix_a, - 0, - W_shared_wmma_matrix_b, - 0, - Conv_wmma_accumulator, - 0, - dtype="handle", + T.evaluate( + T.tvm_mma_sync( + Conv_wmma_accumulator.data, + 0, + Apad_shared_wmma_matrix_a.data, + 0, + W_shared_wmma_matrix_b.data, + 0, + Conv_wmma_accumulator.data, + 0, + dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_mma_sync( - Conv_wmma_accumulator, - 1, - Apad_shared_wmma_matrix_a, - 0, - W_shared_wmma_matrix_b, - 1, - Conv_wmma_accumulator, - 1, - dtype="handle", + T.evaluate( + T.tvm_mma_sync( + Conv_wmma_accumulator.data, + 1, + Apad_shared_wmma_matrix_a.data, + 0, + W_shared_wmma_matrix_b.data, + 1, + Conv_wmma_accumulator.data, + 1, + dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_mma_sync( - Conv_wmma_accumulator, - 2, - Apad_shared_wmma_matrix_a, - 0, - W_shared_wmma_matrix_b, - 2, - Conv_wmma_accumulator, - 2, - dtype="handle", + T.evaluate( + T.tvm_mma_sync( + Conv_wmma_accumulator.data, + 2, + Apad_shared_wmma_matrix_a.data, + 0, + W_shared_wmma_matrix_b.data, + 2, + Conv_wmma_accumulator.data, + 2, + dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_mma_sync( - Conv_wmma_accumulator, - 3, - Apad_shared_wmma_matrix_a, - 0, - W_shared_wmma_matrix_b, - 3, - Conv_wmma_accumulator, - 3, - dtype="handle", + T.evaluate( + T.tvm_mma_sync( + Conv_wmma_accumulator.data, + 3, + Apad_shared_wmma_matrix_a.data, + 0, + W_shared_wmma_matrix_b.data, + 3, + Conv_wmma_accumulator.data, + 3, + dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_mma_sync( - Conv_wmma_accumulator, - 4, - Apad_shared_wmma_matrix_a, - 1, - W_shared_wmma_matrix_b, - 0, - Conv_wmma_accumulator, - 4, - dtype="handle", + T.evaluate( + T.tvm_mma_sync( + Conv_wmma_accumulator.data, + 4, + Apad_shared_wmma_matrix_a.data, + 1, + W_shared_wmma_matrix_b.data, + 0, + Conv_wmma_accumulator.data, + 4, + dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_mma_sync( - Conv_wmma_accumulator, - 5, - Apad_shared_wmma_matrix_a, - 1, - W_shared_wmma_matrix_b, - 1, - Conv_wmma_accumulator, - 5, - dtype="handle", + T.evaluate( + T.tvm_mma_sync( + Conv_wmma_accumulator.data, + 
5, + Apad_shared_wmma_matrix_a.data, + 1, + W_shared_wmma_matrix_b.data, + 1, + Conv_wmma_accumulator.data, + 5, + dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_mma_sync( - Conv_wmma_accumulator, - 6, - Apad_shared_wmma_matrix_a, - 1, - W_shared_wmma_matrix_b, - 2, - Conv_wmma_accumulator, - 6, - dtype="handle", + T.evaluate( + T.tvm_mma_sync( + Conv_wmma_accumulator.data, + 6, + Apad_shared_wmma_matrix_a.data, + 1, + W_shared_wmma_matrix_b.data, + 2, + Conv_wmma_accumulator.data, + 6, + dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_mma_sync( - Conv_wmma_accumulator, - 7, - Apad_shared_wmma_matrix_a, - 1, - W_shared_wmma_matrix_b, - 3, - Conv_wmma_accumulator, - 7, - dtype="handle", + T.evaluate( + T.tvm_mma_sync( + Conv_wmma_accumulator.data, + 7, + Apad_shared_wmma_matrix_a.data, + 1, + W_shared_wmma_matrix_b.data, + 3, + Conv_wmma_accumulator.data, + 7, + dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_store_matrix_sync( - Conv_wmma_accumulator, - 16, - 16, - 16, - 0, - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), - Conv_1.data, - (((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + (tz * 1024)), - 256, - 2, - dtype="handle", - ), - 16, - "row_major", - dtype="handle", - ) - ) - T.evaluate( - T.tvm_store_matrix_sync( - Conv_wmma_accumulator, - 16, - 16, - 16, - 1, - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), - Conv_1.data, - ( + T.evaluate( + T.tvm_store_matrix_sync( + Conv_wmma_accumulator.data, + 16, + 16, + 16, + 0, + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), + Conv_1.data, ( ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + (tz * 1024) - ) - + 256 + ), + 256, + 2, + dtype="handle", ), - 256, - 2, + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_store_matrix_sync( - Conv_wmma_accumulator, - 16, - 16, - 16, - 2, - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), - Conv_1.data, - ( + T.evaluate( + T.tvm_store_matrix_sync( + Conv_wmma_accumulator.data, + 16, + 16, + 16, + 1, + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), + Conv_1.data, ( - ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) - + (tz * 1024) - ) - + 512 + ( + ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + + (tz * 1024) + ) + + 256 + ), + 256, + 2, + dtype="handle", ), - 256, + 16, + "row_major", + dtype="handle", + ) + ) + T.evaluate( + T.tvm_store_matrix_sync( + Conv_wmma_accumulator.data, + 16, + 16, + 16, 2, + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), + Conv_1.data, + ( + ( + ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + + (tz * 1024) + ) + + 512 + ), + 256, + 2, + dtype="handle", + ), + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_store_matrix_sync( - Conv_wmma_accumulator, - 16, - 16, - 16, - 3, - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), - Conv_1.data, - ( + T.evaluate( + T.tvm_store_matrix_sync( + Conv_wmma_accumulator.data, + 16, + 16, + 16, + 3, + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), + Conv_1.data, ( - ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) - + (tz * 1024) - ) - + 768 + ( + ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + + (tz * 1024) + ) + + 768 + ), + 256, + 2, + dtype="handle", ), - 256, - 2, + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_store_matrix_sync( - 
Conv_wmma_accumulator, - 16, - 16, - 16, - 4, - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), - Conv_1.data, - ( + T.evaluate( + T.tvm_store_matrix_sync( + Conv_wmma_accumulator.data, + 16, + 16, + 16, + 4, + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), + Conv_1.data, ( - ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) - + (tz * 1024) - ) - + 1605632 + ( + ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + + (tz * 1024) + ) + + 1605632 + ), + 256, + 2, + dtype="handle", ), - 256, - 2, + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_store_matrix_sync( - Conv_wmma_accumulator, - 16, - 16, - 16, - 5, - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), - Conv_1.data, - ( + T.evaluate( + T.tvm_store_matrix_sync( + Conv_wmma_accumulator.data, + 16, + 16, + 16, + 5, + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), + Conv_1.data, ( - ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) - + (tz * 1024) - ) - + 1605888 + ( + ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + + (tz * 1024) + ) + + 1605888 + ), + 256, + 2, + dtype="handle", ), - 256, - 2, + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_store_matrix_sync( - Conv_wmma_accumulator, - 16, - 16, - 16, - 6, - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), - Conv_1.data, - ( + T.evaluate( + T.tvm_store_matrix_sync( + Conv_wmma_accumulator.data, + 16, + 16, + 16, + 6, + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), + Conv_1.data, ( - ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) - + (tz * 1024) - ) - + 1606144 + ( + ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + + (tz * 1024) + ) + + 1606144 + ), + 256, + 2, + dtype="handle", ), - 256, - 2, + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_store_matrix_sync( - Conv_wmma_accumulator, - 16, - 16, - 16, - 7, - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), - Conv_1.data, - ( + T.evaluate( + T.tvm_store_matrix_sync( + Conv_wmma_accumulator.data, + 16, + 16, + 16, + 7, + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), + Conv_1.data, ( - ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) - + (tz * 1024) - ) - + 1606400 + ( + ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + + (tz * 1024) + ) + + 1606400 + ), + 256, + 2, + dtype="handle", ), - 256, - 2, + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) + return func -def test_opt_conv_tensorcore_lower(): - mod = opt_conv_tensorcore_lower - rt_mod = tvm.script.from_source(mod.script(show_meta=True)) - tvm.ir.assert_structural_equal(mod, rt_mod, True) +def opt_conv_tensorcore_mod_host(): + @T.prim_func + def opt_conv_tensorcore_mod_host( + args: T.handle, + arg_type_ids: T.Buffer[(3,), "int32"], + num_args: T.int32, + out_ret_value: T.handle, + out_ret_tcode: T.handle, + resource_handle: T.handle, + ) -> T.int32: + # function attr dict + T.func_attr( + { + "tir.noalias": True, + "global_symbol": "default_function", + "tir.is_entry_func": True, + "calling_conv": 1, + } + ) + # body + stack_tcode_data: T.Ptr[T.int32] = T.tvm_stack_alloca("arg_tcode", 10, dtype="handle") + stack_tcode = T.buffer_decl([9], "int32", data=stack_tcode_data) + stack_value: T.handle = T.tvm_stack_alloca("arg_value", 10, 
dtype="handle") + assert num_args == 3, "default_function: num_args should be 3" + arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle") + arg0_code: T.int32 = arg_type_ids[0] + arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle") + arg1_code: T.int32 = arg_type_ids[1] + arg2: T.handle = T.tvm_struct_get(args, 2, 12, dtype="handle") + arg2_code: T.int32 = arg_type_ids[2] -@T.prim_func -def opt_conv_tensorcore_mod_host( - args: T.handle, - arg_type_ids: T.handle, - num_args: T.int32, - out_ret_value: T.handle, - out_ret_tcode: T.handle, - resource_handle: T.handle, -) -> T.int32: - # function attr dict - T.func_attr( - { - "tir.noalias": True, - "global_symbol": "default_function", - "tir.is_entry_func": True, - "calling_conv": 1, - } - ) - # body - stack_tcode: T.handle = T.tvm_stack_alloca("arg_tcode", 10, dtype="handle") - stack_value: T.handle = T.tvm_stack_alloca("arg_value", 10, dtype="handle") - assert num_args == 3, "default_function: num_args should be 3" - arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle") - arg0_code: T.int32 = T.load("int32", arg_type_ids, 0) - arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle") - arg1_code: T.int32 = T.load("int32", arg_type_ids, 1) - arg2: T.handle = T.tvm_struct_get(args, 2, 12, dtype="handle") - arg2_code: T.int32 = T.load("int32", arg_type_ids, 2) - A: T.handle = T.tvm_struct_get(arg0, 0, 1, dtype="handle") - T.attr(A, "storage_alignment", 128) - arg0_shape: T.handle = T.tvm_struct_get(arg0, 0, 2, dtype="handle") - arg0_strides: T.handle = T.tvm_struct_get(arg0, 0, 3, dtype="handle") - dev_id: T.int32 = T.tvm_struct_get(arg0, 0, 9, dtype="int32") - W: T.handle = T.tvm_struct_get(arg1, 0, 1, dtype="handle") - T.attr(W, "storage_alignment", 128) - arg1_shape: T.handle = T.tvm_struct_get(arg1, 0, 2, dtype="handle") - arg1_strides: T.handle = T.tvm_struct_get(arg1, 0, 3, dtype="handle") - Conv: T.handle = T.tvm_struct_get(arg2, 0, 1, dtype="handle") - T.attr(Conv, "storage_alignment", 128) - arg2_shape: T.handle = T.tvm_struct_get(arg2, 0, 2, dtype="handle") - arg2_strides: T.handle = T.tvm_struct_get(arg2, 0, 3, dtype="handle") - assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or ( - arg0_code == 4 - ), "default_function: Expect arg[0] to be pointer" - assert (((arg1_code == 3) or (arg1_code == 13)) or (arg1_code == 7)) or ( - arg1_code == 4 - ), "default_function: Expect arg[1] to be pointer" - assert (((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or ( - arg2_code == 4 - ), "default_function: Expect arg[2] to be pointer" - assert 6 == T.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 6" - assert 6 == T.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 6" - assert ( - (T.tvm_struct_get(arg0, 0, 5, dtype="uint8") == T.uint8(2)) - and (T.tvm_struct_get(arg0, 0, 6, dtype="uint8") == T.uint8(16)) - ) and ( - T.tvm_struct_get(arg0, 0, 7, dtype="uint16") == T.uint16(1) - ), "arg0.dtype is expected to be float16" - assert 16 == T.cast( - T.load("int64", arg0_shape, 0), "int32" - ), "Argument arg0.shape[0] has an unsatisfied constraint" - assert 14 == T.cast( - T.load("int64", arg0_shape, 1), "int32" - ), "Argument arg0.shape[1] has an unsatisfied constraint" - assert 14 == T.cast( - T.load("int64", arg0_shape, 2), "int32" - ), "Argument arg0.shape[2] has an unsatisfied constraint" - assert 16 == T.cast( - T.load("int64", arg0_shape, 3), "int32" - ), "Argument arg0.shape[3] has an unsatisfied constraint" - assert 16 == 
T.cast( - T.load("int64", arg0_shape, 4), "int32" - ), "Argument arg0.shape[4] has an unsatisfied constraint" - assert 16 == T.cast( - T.load("int64", arg0_shape, 5), "int32" - ), "Argument arg0.shape[5] has an unsatisfied constraint" - if not (T.isnullptr(arg0_strides, dtype="bool")): + A: T.handle = T.tvm_struct_get(arg0, 0, 1, dtype="handle") + T.attr(A, "storage_alignment", 128) + arg0_shape_data: T.Ptr[T.int64] = T.tvm_struct_get(arg0, 0, 2, dtype="handle") + arg0_shape = T.buffer_decl([6], "int64", data=arg0_shape_data) + arg0_strides_data: T.Ptr[T.int64] = T.tvm_struct_get(arg0, 0, 3, dtype="handle") + arg0_strides = T.buffer_decl([6], "int64", data=arg0_strides_data) + + dev_id: T.int32 = T.tvm_struct_get(arg0, 0, 9, dtype="int32") + + W: T.handle = T.tvm_struct_get(arg1, 0, 1, dtype="handle") + T.attr(W, "storage_alignment", 128) + arg1_shape_data: T.Ptr[T.int64] = T.tvm_struct_get(arg1, 0, 2, dtype="handle") + arg1_shape = T.buffer_decl([6], "int64", data=arg1_shape_data) + arg1_strides_data: T.Ptr[T.int64] = T.tvm_struct_get(arg1, 0, 3, dtype="handle") + arg1_strides = T.buffer_decl([6], "int64", data=arg1_strides_data) + + Conv: T.handle = T.tvm_struct_get(arg2, 0, 1, dtype="handle") + T.attr(Conv, "storage_alignment", 128) + arg2_shape_data: T.Ptr[T.int64] = T.tvm_struct_get(arg2, 0, 2, dtype="handle") + arg2_shape = T.buffer_decl([6], "int64", data=arg2_shape_data) + arg2_strides_data: T.Ptr[T.int64] = T.tvm_struct_get(arg2, 0, 3, dtype="handle") + arg2_strides = T.buffer_decl([6], "int64", data=arg2_strides_data) + + assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or ( + arg0_code == 4 + ), "default_function: Expect arg[0] to be pointer" + assert (((arg1_code == 3) or (arg1_code == 13)) or (arg1_code == 7)) or ( + arg1_code == 4 + ), "default_function: Expect arg[1] to be pointer" + assert (((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or ( + arg2_code == 4 + ), "default_function: Expect arg[2] to be pointer" + assert 6 == T.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 6" + assert 6 == T.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 6" assert ( - ( + (T.tvm_struct_get(arg0, 0, 5, dtype="uint8") == T.uint8(2)) + and (T.tvm_struct_get(arg0, 0, 6, dtype="uint8") == T.uint8(16)) + ) and ( + T.tvm_struct_get(arg0, 0, 7, dtype="uint16") == T.uint16(1) + ), "arg0.dtype is expected to be float16" + assert 16 == T.cast( + arg0_shape[0], "int32" + ), "Argument arg0.shape[0] has an unsatisfied constraint" + assert 14 == T.cast( + arg0_shape[1], "int32" + ), "Argument arg0.shape[1] has an unsatisfied constraint" + assert 14 == T.cast( + arg0_shape[2], "int32" + ), "Argument arg0.shape[2] has an unsatisfied constraint" + assert 16 == T.cast( + arg0_shape[3], "int32" + ), "Argument arg0.shape[3] has an unsatisfied constraint" + assert 16 == T.cast( + arg0_shape[4], "int32" + ), "Argument arg0.shape[4] has an unsatisfied constraint" + assert 16 == T.cast( + arg0_shape[5], "int32" + ), "Argument arg0.shape[5] has an unsatisfied constraint" + if not (T.isnullptr(arg0_strides.data, dtype="bool")): + assert ( ( ( - (1 == T.cast(T.load("int64", arg0_strides, 5), "int32")) - and (16 == T.cast(T.load("int64", arg0_strides, 4), "int32")) + ( + (1 == T.cast(arg0_strides[5], "int32")) + and (16 == T.cast(arg0_strides[4], "int32")) + ) + and (256 == T.cast(arg0_strides[3], "int32")) ) - and (256 == T.cast(T.load("int64", arg0_strides, 3), "int32")) + and (4096 == T.cast(arg0_strides[2], "int32")) ) - and 
(4096 == T.cast(T.load("int64", arg0_strides, 2), "int32")) - ) - and (57344 == T.cast(T.load("int64", arg0_strides, 1), "int32")) - ) and ( - 802816 == T.cast(T.load("int64", arg0_strides, 0), "int32") - ), "arg0.strides: expected to be compact array" - T.evaluate(0) - assert T.uint64(0) == T.tvm_struct_get( - arg0, 0, 8, dtype="uint64" - ), "Argument arg0.byte_offset has an unsatisfied constraint" - assert 2 == T.tvm_struct_get( - arg0, 0, 10, dtype="int32" - ), "Argument arg0.device_type has an unsatisfied constraint" - assert 6 == T.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 6" - assert 6 == T.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 6" - assert ( - (T.tvm_struct_get(arg1, 0, 5, dtype="uint8") == T.uint8(2)) - and (T.tvm_struct_get(arg1, 0, 6, dtype="uint8") == T.uint8(16)) - ) and ( - T.tvm_struct_get(arg1, 0, 7, dtype="uint16") == T.uint16(1) - ), "arg1.dtype is expected to be float16" - assert 3 == T.cast( - T.load("int64", arg1_shape, 0), "int32" - ), "Argument arg1.shape[0] has an unsatisfied constraint" - assert 3 == T.cast( - T.load("int64", arg1_shape, 1), "int32" - ), "Argument arg1.shape[1] has an unsatisfied constraint" - assert 16 == T.cast( - T.load("int64", arg1_shape, 2), "int32" - ), "Argument arg1.shape[2] has an unsatisfied constraint" - assert 32 == T.cast( - T.load("int64", arg1_shape, 3), "int32" - ), "Argument arg1.shape[3] has an unsatisfied constraint" - assert 16 == T.cast( - T.load("int64", arg1_shape, 4), "int32" - ), "Argument arg1.shape[4] has an unsatisfied constraint" - assert 16 == T.cast( - T.load("int64", arg1_shape, 5), "int32" - ), "Argument arg1.shape[5] has an unsatisfied constraint" - if not (T.isnullptr(arg1_strides, dtype="bool")): + and (57344 == T.cast(arg0_strides[1], "int32")) + ) and ( + 802816 == T.cast(arg0_strides[0], "int32") + ), "arg0.strides: expected to be compact array" + T.evaluate(0) + assert T.uint64(0) == T.tvm_struct_get( + arg0, 0, 8, dtype="uint64" + ), "Argument arg0.byte_offset has an unsatisfied constraint" + assert 2 == T.tvm_struct_get( + arg0, 0, 10, dtype="int32" + ), "Argument arg0.device_type has an unsatisfied constraint" + assert 6 == T.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 6" + assert 6 == T.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 6" assert ( - ( + (T.tvm_struct_get(arg1, 0, 5, dtype="uint8") == T.uint8(2)) + and (T.tvm_struct_get(arg1, 0, 6, dtype="uint8") == T.uint8(16)) + ) and ( + T.tvm_struct_get(arg1, 0, 7, dtype="uint16") == T.uint16(1) + ), "arg1.dtype is expected to be float16" + assert 3 == T.cast( + arg1_shape[0], "int32" + ), "Argument arg1.shape[0] has an unsatisfied constraint" + assert 3 == T.cast( + arg1_shape[1], "int32" + ), "Argument arg1.shape[1] has an unsatisfied constraint" + assert 16 == T.cast( + arg1_shape[2], "int32" + ), "Argument arg1.shape[2] has an unsatisfied constraint" + assert 32 == T.cast( + arg1_shape[3], "int32" + ), "Argument arg1.shape[3] has an unsatisfied constraint" + assert 16 == T.cast( + arg1_shape[4], "int32" + ), "Argument arg1.shape[4] has an unsatisfied constraint" + assert 16 == T.cast( + arg1_shape[5], "int32" + ), "Argument arg1.shape[5] has an unsatisfied constraint" + if not (T.isnullptr(arg1_strides.data, dtype="bool")): + assert ( ( ( - (1 == T.cast(T.load("int64", arg1_strides, 5), "int32")) - and (16 == T.cast(T.load("int64", arg1_strides, 4), "int32")) + ( + (1 == T.cast(arg1_strides[5], "int32")) + and (16 == 
T.cast(arg1_strides[4], "int32")) + ) + and (256 == T.cast(arg1_strides[3], "int32")) ) - and (256 == T.cast(T.load("int64", arg1_strides, 3), "int32")) + and (8192 == T.cast(arg1_strides[2], "int32")) ) - and (8192 == T.cast(T.load("int64", arg1_strides, 2), "int32")) - ) - and (131072 == T.cast(T.load("int64", arg1_strides, 1), "int32")) - ) and ( - 393216 == T.cast(T.load("int64", arg1_strides, 0), "int32") - ), "arg1.strides: expected to be compact array" - T.evaluate(0) - assert T.uint64(0) == T.tvm_struct_get( - arg1, 0, 8, dtype="uint64" - ), "Argument arg1.byte_offset has an unsatisfied constraint" - assert 2 == T.tvm_struct_get( - arg1, 0, 10, dtype="int32" - ), "Argument arg1.device_type has an unsatisfied constraint" - assert dev_id == T.tvm_struct_get( - arg1, 0, 9, dtype="int32" - ), "Argument arg1.device_id has an unsatisfied constraint" - assert 6 == T.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 6" - assert 6 == T.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 6" - assert ( - (T.tvm_struct_get(arg2, 0, 5, dtype="uint8") == T.uint8(2)) - and (T.tvm_struct_get(arg2, 0, 6, dtype="uint8") == T.uint8(32)) - ) and ( - T.tvm_struct_get(arg2, 0, 7, dtype="uint16") == T.uint16(1) - ), "arg2.dtype is expected to be float32" - assert 16 == T.cast( - T.load("int64", arg2_shape, 0), "int32" - ), "Argument arg2.shape[0] has an unsatisfied constraint" - assert 14 == T.cast( - T.load("int64", arg2_shape, 1), "int32" - ), "Argument arg2.shape[1] has an unsatisfied constraint" - assert 14 == T.cast( - T.load("int64", arg2_shape, 2), "int32" - ), "Argument arg2.shape[2] has an unsatisfied constraint" - assert 32 == T.cast( - T.load("int64", arg2_shape, 3), "int32" - ), "Argument arg2.shape[3] has an unsatisfied constraint" - assert 16 == T.cast( - T.load("int64", arg2_shape, 4), "int32" - ), "Argument arg2.shape[4] has an unsatisfied constraint" - assert 16 == T.cast( - T.load("int64", arg2_shape, 5), "int32" - ), "Argument arg2.shape[5] has an unsatisfied constraint" - if not (T.isnullptr(arg2_strides, dtype="bool")): + and (131072 == T.cast(arg1_strides[1], "int32")) + ) and ( + 393216 == T.cast(arg1_strides[0], "int32") + ), "arg1.strides: expected to be compact array" + T.evaluate(0) + assert T.uint64(0) == T.tvm_struct_get( + arg1, 0, 8, dtype="uint64" + ), "Argument arg1.byte_offset has an unsatisfied constraint" + assert 2 == T.tvm_struct_get( + arg1, 0, 10, dtype="int32" + ), "Argument arg1.device_type has an unsatisfied constraint" + assert dev_id == T.tvm_struct_get( + arg1, 0, 9, dtype="int32" + ), "Argument arg1.device_id has an unsatisfied constraint" + assert 6 == T.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 6" + assert 6 == T.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 6" assert ( - ( + (T.tvm_struct_get(arg2, 0, 5, dtype="uint8") == T.uint8(2)) + and (T.tvm_struct_get(arg2, 0, 6, dtype="uint8") == T.uint8(32)) + ) and ( + T.tvm_struct_get(arg2, 0, 7, dtype="uint16") == T.uint16(1) + ), "arg2.dtype is expected to be float32" + assert 16 == T.cast( + arg2_shape[0], "int32" + ), "Argument arg2.shape[0] has an unsatisfied constraint" + assert 14 == T.cast( + arg2_shape[1], "int32" + ), "Argument arg2.shape[1] has an unsatisfied constraint" + assert 14 == T.cast( + arg2_shape[2], "int32" + ), "Argument arg2.shape[2] has an unsatisfied constraint" + assert 32 == T.cast( + arg2_shape[3], "int32" + ), "Argument arg2.shape[3] has an unsatisfied constraint" + assert 
16 == T.cast( + arg2_shape[4], "int32" + ), "Argument arg2.shape[4] has an unsatisfied constraint" + assert 16 == T.cast( + arg2_shape[5], "int32" + ), "Argument arg2.shape[5] has an unsatisfied constraint" + if not (T.isnullptr(arg2_strides.data, dtype="bool")): + assert ( ( ( - (1 == T.cast(T.load("int64", arg2_strides, 5), "int32")) - and (16 == T.cast(T.load("int64", arg2_strides, 4), "int32")) + ( + (1 == T.cast(arg2_strides[5], "int32")) + and (16 == T.cast(arg2_strides[4], "int32")) + ) + and (256 == T.cast(arg2_strides[3], "int32")) ) - and (256 == T.cast(T.load("int64", arg2_strides, 3), "int32")) + and (8192 == T.cast(arg2_strides[2], "int32")) ) - and (8192 == T.cast(T.load("int64", arg2_strides, 2), "int32")) + and (114688 == T.cast(arg2_strides[1], "int32")) + ) and ( + 1605632 == T.cast(arg2_strides[0], "int32") + ), "arg2.strides: expected to be compact array" + T.evaluate(0) + assert T.uint64(0) == T.tvm_struct_get( + arg2, 0, 8, dtype="uint64" + ), "Argument arg2.byte_offset has an unsatisfied constraint" + assert 2 == T.tvm_struct_get( + arg2, 0, 10, dtype="int32" + ), "Argument arg2.device_type has an unsatisfied constraint" + assert dev_id == T.tvm_struct_get( + arg2, 0, 9, dtype="int32" + ), "Argument arg2.device_id has an unsatisfied constraint" + T.evaluate(T.tvm_struct_set(stack_value, 0, 12, T.cast(2, "int64"), dtype="int32")) + stack_tcode[0] = 0 + T.evaluate(T.tvm_struct_set(stack_value, 1, 12, T.cast(dev_id, "int64"), dtype="int32")) + stack_tcode[1] = 0 + T.evaluate( + T.tvm_call_packed_lowered( + "__tvm_set_device", stack_value, stack_tcode.data, 0, 2, dtype="int32" + ) + ) + T.attr(0, "compute_scope", "default_function_compute_") + T.evaluate(T.tvm_struct_set(stack_value, 0, 12, A, dtype="int32")) + stack_tcode[0] = 3 + T.evaluate(T.tvm_struct_set(stack_value, 1, 12, W, dtype="int32")) + stack_tcode[1] = 3 + T.evaluate(T.tvm_struct_set(stack_value, 2, 12, Conv, dtype="int32")) + stack_tcode[2] = 3 + T.evaluate(T.tvm_struct_set(stack_value, 3, 12, T.cast(196, "int64"), dtype="int32")) + stack_tcode[3] = 0 + T.evaluate(T.tvm_struct_set(stack_value, 4, 12, T.cast(2, "int64"), dtype="int32")) + stack_tcode[4] = 0 + T.evaluate(T.tvm_struct_set(stack_value, 5, 12, T.cast(4, "int64"), dtype="int32")) + stack_tcode[5] = 0 + T.evaluate(T.tvm_struct_set(stack_value, 6, 12, T.cast(4, "int64"), dtype="int32")) + stack_tcode[6] = 0 + T.evaluate(T.tvm_struct_set(stack_value, 7, 12, T.cast(2, "int64"), dtype="int32")) + stack_tcode[7] = 0 + T.evaluate(T.tvm_struct_set(stack_value, 8, 12, T.cast(32, "int64"), dtype="int32")) + stack_tcode[8] = 0 + T.evaluate( + T.tvm_call_packed_lowered( + "default_function_kernel0", stack_value, stack_tcode.data, 0, 9, dtype="int32" ) - and (114688 == T.cast(T.load("int64", arg2_strides, 1), "int32")) - ) and ( - 1605632 == T.cast(T.load("int64", arg2_strides, 0), "int32") - ), "arg2.strides: expected to be compact array" - T.evaluate(0) - assert T.uint64(0) == T.tvm_struct_get( - arg2, 0, 8, dtype="uint64" - ), "Argument arg2.byte_offset has an unsatisfied constraint" - assert 2 == T.tvm_struct_get( - arg2, 0, 10, dtype="int32" - ), "Argument arg2.device_type has an unsatisfied constraint" - assert dev_id == T.tvm_struct_get( - arg2, 0, 9, dtype="int32" - ), "Argument arg2.device_id has an unsatisfied constraint" - T.evaluate(T.tvm_struct_set(stack_value, 0, 12, T.cast(2, "int64"), dtype="int32")) - stack_tcode[0] = 0 - T.evaluate(T.tvm_struct_set(stack_value, 1, 12, T.cast(dev_id, "int64"), dtype="int32")) - stack_tcode[1] = 0 - 
T.evaluate( - T.tvm_call_packed_lowered("__tvm_set_device", stack_value, stack_tcode, 0, 2, dtype="int32") - ) - T.attr(0, "compute_scope", "default_function_compute_") - T.evaluate(T.tvm_struct_set(stack_value, 0, 12, A, dtype="int32")) - stack_tcode[0] = 3 - T.evaluate(T.tvm_struct_set(stack_value, 1, 12, W, dtype="int32")) - stack_tcode[1] = 3 - T.evaluate(T.tvm_struct_set(stack_value, 2, 12, Conv, dtype="int32")) - stack_tcode[2] = 3 - T.evaluate(T.tvm_struct_set(stack_value, 3, 12, T.cast(196, "int64"), dtype="int32")) - stack_tcode[3] = 0 - T.evaluate(T.tvm_struct_set(stack_value, 4, 12, T.cast(2, "int64"), dtype="int32")) - stack_tcode[4] = 0 - T.evaluate(T.tvm_struct_set(stack_value, 5, 12, T.cast(4, "int64"), dtype="int32")) - stack_tcode[5] = 0 - T.evaluate(T.tvm_struct_set(stack_value, 6, 12, T.cast(4, "int64"), dtype="int32")) - stack_tcode[6] = 0 - T.evaluate(T.tvm_struct_set(stack_value, 7, 12, T.cast(2, "int64"), dtype="int32")) - stack_tcode[7] = 0 - T.evaluate(T.tvm_struct_set(stack_value, 8, 12, T.cast(32, "int64"), dtype="int32")) - stack_tcode[8] = 0 - T.evaluate( - T.tvm_call_packed_lowered( - "default_function_kernel0", stack_value, stack_tcode, 0, 9, dtype="int32" ) - ) + return opt_conv_tensorcore_mod_host -def test_opt_conv_tensorcore_mod_host(): - mod = opt_conv_tensorcore_mod_host - rt_mod = tvm.script.from_source(mod.script(show_meta=True)) - tvm.ir.assert_structural_equal(mod, rt_mod, True) +def vthread_func(): + @T.prim_func + def vthread_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [256], "float32") + C = T.match_buffer(c, [256], "float32") + + i0 = T.env_thread("blockIdx.x") + i1 = T.env_thread("threadIdx.x") + i2 = T.env_thread("vthread") + + T.launch_thread(i0, 4) + T.launch_thread(i1, 2) + T.launch_thread(i2, 2) + B = T.allocate([16], "float32", "local") + for j in range(16): + B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + T.float32(1) + for j in range(16): + C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * T.float32(2) -@T.prim_func -def vthread_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") - - i0 = T.env_thread("blockIdx.x") - i1 = T.env_thread("threadIdx.x") - i2 = T.env_thread("vthread") - - T.launch_thread(i0, 4) - T.launch_thread(i1, 2) - T.launch_thread(i2, 2) - B = T.allocate([16], "float32", "local") - for j in range(16): - B[j] = T.load("float32", A.data, i0 * 64 + i1 * 32 + i2 * 16 + j) + T.float32(1) - for j in range(16): - C.data[i0 * 64 + i1 * 32 + i2 * 16 + j] = T.load("float32", B, j) * T.float32(2) - - -def test_vthread(): - func = vthread_func - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func, True) + return vthread_func -@T.prim_func -def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [128, 128]) - B = T.match_buffer(b, [128, 128]) - C = T.match_buffer(c, [128, 128]) +def matmul(): + @T.prim_func + def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) - for i, j, k in T.grid(128, 128, 128): - with T.block("update"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - with T.init(): + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + return matmul + + +def matmul_original(): + @T.prim_func + 
def matmul_original(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + + for i, j in T.grid(128, 128): + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for k in range(128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] -@T.prim_func -def matmul_original(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [128, 128]) - B = T.match_buffer(b, [128, 128]) - C = T.match_buffer(c, [128, 128]) + return matmul_original - for i, j in T.grid(128, 128): - with T.block("init"): - vi, vj = T.axis.remap("SS", [i, j]) - C[vi, vj] = T.float32(0) - for k in range(128): - with T.block("update"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] +def element_wise(): + @T.prim_func + def element_wise(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + T.float32(1) -@T.prim_func -def element_wise(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (128, 128), "float32") - C = T.match_buffer(c, (128, 128), "float32") - B = T.alloc_buffer((128, 128), "float32") + return element_wise - for i, j in T.grid(128, 128): - with T.block("B"): - vi, vj = T.axis.remap("SS", [i, j]) - B[vi, vj] = A[vi, vj] * T.float32(2) - for i, j in T.grid(128, 128): - with T.block("C"): - vi, vj = T.axis.remap("SS", [i, j]) - C[vi, vj] = B[vi, vj] + T.float32(1) +def predicate(): + @T.prim_func + def predicate(b: T.handle, c: T.handle) -> None: + B = T.match_buffer(b, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") -@T.prim_func -def predicate(b: T.handle, c: T.handle) -> None: - B = T.match_buffer(b, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") + for i, jo, ji in T.grid(16, 4, 5): + with T.block("update"): + vi = T.axis.S(16, i) + vj = T.axis.S(16, jo * 4 + ji) + T.where(jo * 4 + ji < 16) + C[vi, vj] = B[vi, vj] + T.float32(1) - for i, jo, ji in T.grid(16, 4, 5): - with T.block("update"): - vi = T.axis.S(16, i) - vj = T.axis.S(16, jo * 4 + ji) - T.where(jo * 4 + ji < 16) - C[vi, vj] = B[vi, vj] + T.float32(1) + return predicate def test_module_define(): - func1 = tvm.ir.IRModule({"matmul": matmul})["matmul"] - func2 = tvm.ir.IRModule({"element_wise": element_wise})["element_wise"] - func3 = tvm.ir.IRModule({"predicate": predicate})["predicate"] + func1 = tvm.ir.IRModule({"matmul": matmul()})["matmul"] + func2 = tvm.ir.IRModule({"element_wise": element_wise()})["element_wise"] + func3 = tvm.ir.IRModule({"predicate": predicate()})["predicate"] mod1 = tvm.ir.IRModule({"func1": func1, "func2": func2, "func3": func3}) - mod2 = tvm.ir.IRModule({"func1": matmul, "func2": element_wise, "func3": predicate}) + mod2 = tvm.ir.IRModule({"func1": matmul(), "func2": element_wise(), "func3": predicate()}) tvm.ir.assert_structural_equal(mod1, mod2) -def test_matmul(): - func = matmul - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func) - - def 
test_matmul_original(): - func = matmul_original + func = matmul_original() rt_func = tvm.script.from_source(func.script(show_meta=True)) tvm.ir.assert_structural_equal(func, rt_func) @@ -2758,7 +2584,7 @@ def test_matmul_original(): def test_element_wise(): - func = element_wise + func = element_wise() rt_func = tvm.script.from_source(func.script(show_meta=True)) tvm.ir.assert_structural_equal(func, rt_func) @@ -2774,7 +2600,7 @@ def test_element_wise(): def test_predicate(): - func = predicate + func = predicate() rt_func = tvm.script.from_source(func.script(show_meta=True)) tvm.ir.assert_structural_equal(func, rt_func) @@ -2785,20 +2611,23 @@ def test_predicate(): assert isinstance(rt_func.body.block.body.body.body.body.block, tir.stmt.Block) -@T.prim_func -def for_thread_binding(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - B = T.match_buffer(b, (16, 16), "float32") +def for_thread_binding(): + @T.prim_func + def for_thread_binding(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + B = T.match_buffer(b, (16, 16), "float32") + + for i in T.thread_binding(0, 16, thread="threadIdx.x"): + for j in T.thread_binding( + 0, 16, thread="threadIdx.y", annotations={"attr_key": "attr_value"} + ): + A[i, j] = B[i, j] + T.float32(1) - for i in T.thread_binding(0, 16, thread="threadIdx.x"): - for j in T.thread_binding( - 0, 16, thread="threadIdx.y", annotations={"attr_key": "attr_value"} - ): - A[i, j] = B[i, j] + T.float32(1) + return for_thread_binding def test_for_thread_binding(): - func = for_thread_binding + func = for_thread_binding() rt_func = tvm.script.from_source(func.script(show_meta=True)) tvm.ir.assert_structural_equal(func, rt_func) @@ -2811,25 +2640,28 @@ def test_for_thread_binding(): assert rt_func.body.body.annotations["attr_key"] == "attr_value" -@T.prim_func -def match_buffer_region(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, (16, 16, 16), "float32") - B = T.match_buffer(b, (1), "float32") - - for i, j in T.grid(16, 4): - with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - C = T.match_buffer(A[0:16, vi, vj * 4 : vj * 4 + 4], (16, 1, 4)) - for ii in range(4): - with T.block(): - vii = T.axis.S(4, ii) - D = T.match_buffer(C[vii * 4 : vii * 4 + 4, 0, 0:4], (4, 1, 4)) - for i, j in T.grid(4, 4): - B[0] += D[i, 0, j] +def match_buffer_region(): + @T.prim_func + def match_buffer_region(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (16, 16, 16), "float32") + B = T.match_buffer(b, (1), "float32") + + for i, j in T.grid(16, 4): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + C = T.match_buffer(A[0:16, vi, vj * 4 : vj * 4 + 4], (16, 1, 4)) + for ii in range(4): + with T.block(): + vii = T.axis.S(4, ii) + D = T.match_buffer(C[vii * 4 : vii * 4 + 4, 0, 0:4], (4, 1, 4)) + for i, j in T.grid(4, 4): + B[0] += D[i, 0, j] + + return match_buffer_region def test_match_buffer_region(): - func = match_buffer_region + func = match_buffer_region() rt_func = tvm.script.from_source(func.script(show_meta=True)) tvm.ir.assert_structural_equal(func, rt_func) @@ -2852,26 +2684,29 @@ def test_match_buffer_region(): tvm.ir.assert_structural_equal(buffer_D.shape, [4, 1, 4]) -@T.prim_func -def block_elements(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - B = T.match_buffer(b, (1, 1), "float32") - - with T.block("update"): - vi = T.axis.S(1, 0) - T.where(True) - T.reads(A[0:16, 0:16]) - T.writes(B[0, 0]) - T.block_attr({"attr_key": "attr_value"}) - C = T.alloc_buffer((4, 
4), dtype="float32") - D = T.match_buffer(A[0:4, 0], (4, 1)) - with T.init(): - B[0, 0] = T.float32(0) - B[0, 0] = A[0, 0] + B[0, 0] + C[1, 1] + D[2] +def block_elements(): + @T.prim_func + def block_elements(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + B = T.match_buffer(b, (1, 1), "float32") + + with T.block("update"): + vi = T.axis.S(1, 0) + T.where(True) + T.reads(A[0:16, 0:16]) + T.writes(B[0, 0]) + T.block_attr({"attr_key": "attr_value"}) + C = T.alloc_buffer((4, 4), dtype="float32") + D = T.match_buffer(A[0:4, 0], (4, 1)) + with T.init(): + B[0, 0] = T.float32(0) + B[0, 0] = A[0, 0] + B[0, 0] + C[1, 1] + D[2, 0] + + return block_elements def test_block_elements(): - func = block_elements + func = block_elements() rt_func = tvm.script.from_source(func.script(show_meta=True)) tvm.ir.assert_structural_equal(func, rt_func) @@ -2885,26 +2720,29 @@ def test_block_elements(): assert block.annotations["attr_key"] == "attr_value" -@T.prim_func -def opaque_block(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - B = T.match_buffer(b, (16, 16), "float32") +def opaque_block(): + @T.prim_func + def opaque_block(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + B = T.match_buffer(b, (16, 16), "float32") - for i in range(16): - for j in range(16): - with T.block(): - T.reads([]) - T.writes(A[i, j]) - A[i, j] = T.float32(0) - with T.block(): - T.reads([A[i, 0:16]]) - T.writes([B[i, 0:16]]) + for i in range(16): for j in range(16): - B[i, j] = A[i, j] + with T.block(): + T.reads([]) + T.writes(A[i, j]) + A[i, j] = T.float32(0) + with T.block(): + T.reads([A[i, 0:16]]) + T.writes([B[i, 0:16]]) + for j in range(16): + B[i, j] = A[i, j] + + return opaque_block def test_opaque_block(): - func = opaque_block + func = opaque_block() rt_func = tvm.script.from_source(func.script(show_meta=True)) tvm.ir.assert_structural_equal(func, rt_func) @@ -2920,194 +2758,170 @@ def test_opaque_block(): assert len(root_block.body.body[1].block.iter_vars) == 0 -@tvm.script.ir_module -class Module4: - # There is an ongoing (python)dict->(c++)Map->(python)dict issue which potentially - # changes order of the items in dict after roundtrip due to map not support order - # of insertion while dict does. Hence func 'def A(a: T.handle, c: T.handle) -> None' - # is commented - # - # test: - # d = {"B": 1, "A": 2} - # m = tvm.runtime.convert(d) - # assert d.keys() == m.keys(), f"Order changed from {list(d.keys())} to {list(m.keys())}" +def module_const(): + @tvm.script.ir_module + class Module4: + # There is an ongoing (python)dict->(c++)Map->(python)dict issue which potentially + # changes order of the items in dict after roundtrip due to map not support order + # of insertion while dict does. 
Hence the func 'def A(a: T.handle, c: T.handle) -> None' + # is commented out + # + # test: + # d = {"B": 1, "A": 2} + # m = tvm.runtime.convert(d) + # assert d.keys() == m.keys(), f"Order changed from {list(d.keys())} to {list(m.keys())}" - """ - @T.prim_func - def A(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (10), "int32") - C = T.match_buffer(c, (10), "int32") - B = T.alloc_buffer((10), "int32") + """ + @T.prim_func + def A(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (10), "int32") + C = T.match_buffer(c, (10), "int32") + B = T.alloc_buffer((10), "int32") - K1 = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) - for x in T.serial(0, 10): - B[x] = A[x] + T.load("int32", K1, x) + K1 = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + for x in T.serial(0, 10): + B[x] = A[x] + T.load("int32", K1, x) - for x in T.serial(0, 10): - C[x] = B[x] - """ + for x in T.serial(0, 10): + C[x] = B[x] + """ + + @T.prim_func + def B(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (10), "int32") + C = T.match_buffer(c, (10), "int32") + B = T.alloc_buffer((10), "int32") + + K1 = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + for x in T.serial(0, 10): + B[x] = A[x] + K1[x] + + K2 = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + for x in T.serial(0, 10): + B[x] = B[x] + K2[x] + + for x in T.serial(0, 10): + C[x] = B[x] + return Module4 + + +def constant(): @T.prim_func - def B(a: T.handle, c: T.handle) -> None: + def constant(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (10), "int32") C = T.match_buffer(c, (10), "int32") B = T.alloc_buffer((10), "int32") - - K1 = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) - for x in T.serial(0, 10): - B[x] = A[x] + T.load("int32", K1, x) - - K2 = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + K = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) for x in T.serial(0, 10): - B[x] = B[x] + T.load("int32", K2, x) + B[x] = A[x] + K[x] for x in T.serial(0, 10): C[x] = B[x] + return constant -def test_module_const(): - func = Module4 - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func) - - -@T.prim_func -def constant(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (10), "int32") - C = T.match_buffer(c, (10), "int32") - B = T.alloc_buffer((10), "int32") - K = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) - for x in T.serial(0, 10): - B[x] = A[x] + T.load("int32", K, x) - - for x in T.serial(0, 10): - C[x] = B[x] - - -def test_const(): - func = constant - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func) - - -@T.prim_func -def rank0(a: T.handle) -> None: - A = T.match_buffer(a, (), "float32") - B = T.alloc_buffer((), "float32") - A[()] = 2 - B[()] = A[()] - - -def test_rank0_buffers(): - func = rank0 - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func) - - -@T.prim_func -def rank0_block(a: T.handle) -> None: - A = T.match_buffer(a, (), "float32") - B = T.alloc_buffer((), "float32") - T.store(B.data, 0, T.load("float32", A.data, 0)) - - with T.block("update") as []: - T.reads([A[()]]) - T.writes([B[()]]) - for i in range(1): - B[()] = A[()] +def rank0(): + @T.prim_func + def rank0(a: T.handle) -> None: + A = T.match_buffer(a, (), "float32") + B = T.alloc_buffer((), "float32") + A[()] = 2 + B[()] = A[()] -def 
test_rank0_blocks(): - func = rank0_block - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func) + return rank0 -@T.prim_func -def select(a: T.handle) -> None: - A = T.match_buffer(a, (), "float32") - A[()] = T.Select(True, 1, 2) +def rank0_block(): + @T.prim_func + def rank0_block(a: T.handle) -> None: + A = T.match_buffer(a, (), "float32") + B = T.alloc_buffer((), "float32") + B[()] = A[()] + with T.block("update") as []: + T.reads([A[()]]) + T.writes([B[()]]) + for i in range(1): + B[()] = A[()] -def test_select(): - func = select - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func) + return rank0_block -@T.prim_func -def minmax(a: T.handle) -> None: - A = T.match_buffer(a, (), "float32") - A[()] = T.min(1, 2) - A[()] = T.max(1, 2) +def select(): + @T.prim_func + def select(a: T.handle) -> None: + A = T.match_buffer(a, (), "float32") + A[()] = T.Select(True, 1, 2) + return select -def test_minmax(): - func = minmax - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func) +def minmax(): + @T.prim_func + def minmax(a: T.handle) -> None: + A = T.match_buffer(a, (), "float32") + A[()] = T.min(1, 2) + A[()] = T.max(1, 2) -@T.prim_func -def abs(a: T.handle) -> None: - A = T.match_buffer(a, (128, 128), "float32") + return minmax - for i, j in T.grid(128, 128): - with T.block("A"): - vi, vj = T.axis.remap("SS", [i, j]) - A[vi, vj] = T.abs(A[vi, vj]) +def abs(): + @T.prim_func + def abs(a: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") -def test_abs(): - func = abs - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func) + for i, j in T.grid(128, 128): + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = T.abs(A[vi, vj]) + return abs -@T.prim_func -def constant_folding(a: T.handle) -> None: - A = T.match_buffer(a, (), "float32") - A[()] = T.min(2.2, 5.2) - A[()] = T.max(T.float32(2.2), T.float32(T.float32(5.2))) - A[()] = T.min(2.2, 5.0) +def constant_folding(): + @T.prim_func + def constant_folding(a: T.handle) -> None: + A = T.match_buffer(a, (), "float32") + A[()] = T.min(2.2, 5.2) + A[()] = T.max(T.float32(2.2), T.float32(T.float32(5.2))) + A[()] = T.min(2.2, 5.0) -def test_constant_folding(): - func = constant_folding - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func) + return constant_folding -@T.prim_func -def simplify_bracket() -> None: - a = T.var("int32") - b = T.var("int32") - c = T.var("int32") - d = T.var("int32") - T.evaluate(a + b * (c + d)) +def simplify_bracket(): + @T.prim_func + def simplify_bracket() -> None: + a = T.var("int32") + b = T.var("int32") + c = T.var("int32") + d = T.var("int32") + T.evaluate(a + b * (c + d)) + return simplify_bracket -def test_simplify_bracket(): - func = simplify_bracket - out_str = func.script(show_meta=True) - assert out_str.count("a + b * (c + d)") == 1 +def var_with_same_name(): + @T.prim_func + def var_with_same_name(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = 0 + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = 0 -@T.prim_func -def var_with_same_name(a: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - for i, j in T.grid(16, 16): - 
with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - A[vi, vj] = 0 - for i, j in T.grid(16, 16): - with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - A[vi, vj] = 0 + return var_with_same_name def test_same_name_var(): - func = var_with_same_name + func = var_with_same_name() out_str = func.script(tir_prefix="T", show_meta=True) rt_func = tvm.script.from_source(out_str) tvm.ir.assert_structural_equal(func, rt_func) @@ -3121,124 +2935,115 @@ def test_same_name_var(): assert out_str.find("i_") == -1 -@T.prim_func -def while_loop(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, (16,), "float32") - B = T.match_buffer(b, (16,), "float32") - i = T.alloc_buffer((), "int32", scope="local") - for ii in range(16): - with T.block(): - vi = T.axis.S(16, ii) - B[vi] = 0 - while i[()] < 10: - for j in range(16): - B[j] += A[j] - +def while_loop(): + @T.prim_func + def while_loop(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + i = T.alloc_buffer((), "int32", scope="local") + for ii in range(16): + with T.block(): + vi = T.axis.S(16, ii) + B[vi] = 0 + while i[()] < 10: + for j in range(16): + B[j] += A[j] -def test_while_loop(): - rt_func = tvm.script.from_source(while_loop.script(show_meta=True)) - tvm.ir.assert_structural_equal(while_loop, rt_func) + return while_loop # fmt: off -@T.prim_func -def primfunc_with_allocate_annotations(placeholder_28: T.handle, T_cast_6: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) - placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - # body - tensor_2 = T.allocate([200704], "uint8", "global", annotations={"attr1_key": "attr1_value"}) - for ax0_ax1_fused_4 in T.serial(0, 56): - for ax2_4 in T.serial(0, 56): - for ax3_init in T.serial(0, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) - for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) - for ax0_ax1_fused_5 in T.serial(0, 56): - for ax2_5, ax3_3 in T.grid(56, 64): - T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) +def primfunc_with_allocate_annotations(): + @T.prim_func + def primfunc_with_allocate_annotations(placeholder_28: T.handle, T_cast_6: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) + placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [200704], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + tensor_2 = T.allocate([200704], "uint8", "global", annotations={"attr1_key": "attr1_value"}) + for 
ax0_ax1_fused_4 in T.serial(0, 56): + for ax2_4 in T.serial(0, 56): + for ax3_init in T.serial(0, 64): + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init)] = T.uint8(0) + for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)] = T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")) + for ax0_ax1_fused_5 in T.serial(0, 56): + for ax2_5, ax3_3 in T.grid(56, 64): + T_cast_7[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)] = T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16") + + return primfunc_with_allocate_annotations # fmt: on -def test_primfunc_with_allocate_annotations(): - func = primfunc_with_allocate_annotations - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func, True) # fmt: off -@T.prim_func -def comm_reducer_single_reduce_group(a: T.handle, b: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - threadIdx_x = T.env_thread("threadIdx.x") - A = T.match_buffer(a, [128, 128], dtype="float32") - for i in T.serial(0, 128): - T.launch_thread(threadIdx_x, 128) - reduce_temp0 = T.allocate([1], "float32", "local") - with T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")): - T.evaluate(T.tvm_thread_allreduce(T.uint32(1), T.load("float32", A.data, i * 128 + threadIdx_x), True, reduce_temp0, threadIdx_x, dtype="handle")) - - -@T.prim_func -def comm_reducer_multiple_reduce_groups(a: T.handle, b: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - threadIdx_x = T.env_thread("threadIdx.x") - A = T.match_buffer(a, [128, 128], dtype="float32") - for i in T.serial(0, 128): - T.launch_thread(threadIdx_x, 128) - reduce_temp0 = T.allocate([1], "float32", "local") - with T.attr(T.comm_reducer(lambda x0, x1, y0, y1: (T.Select((x1 >= y1), x0, y0), T.Select((x1 >= y1), x1, y1)), [T.int32(-1), T.min_value("float32")]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")): - T.evaluate(T.tvm_thread_allreduce(T.uint32(1), T.load("float32", A.data, i * 128 + threadIdx_x), True, reduce_temp0, threadIdx_x, dtype="handle")) - - -@T.prim_func -def multiple_commreducer() -> None: - normal_reduce_temp0 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") - normal_reduce_temp1 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") - reduce_temp0 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") - reduce_temp1 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") - for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): - with T.block("T_softmax_maxelem_cross_thread_reduction"): - T.attr(T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")) - T.evaluate(T.tvm_thread_allreduce(T.uint32(1), normal_reduce_temp0[0], True, reduce_temp0.data, ax0_1, dtype="handle")) - for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): - with T.block("T_softmax_expsum_cross_thread_reduction"): - T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")) - 
 
 
-def test_primfunc_with_single_reduce_group_commreducer():
-    func = comm_reducer_single_reduce_group
-    rt_func = tvm.script.from_source(func.script(show_meta=True))
-    tvm.ir.assert_structural_equal(func, rt_func, True)
 
 
+def comm_reducer_multiple_reduce_groups():
+    @T.prim_func
+    def comm_reducer_multiple_reduce_groups(a: T.handle, b: T.handle) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        threadIdx_x = T.env_thread("threadIdx.x")
+        A = T.match_buffer(a, [128 * 128], dtype="float32")
+        for i in T.serial(0, 128):
+            T.launch_thread(threadIdx_x, 128)
+            reduce_temp0 = T.allocate([1], "float32", "local")
+            with T.attr(T.comm_reducer(lambda x0, x1, y0, y1: (T.Select((x1 >= y1), x0, y0), T.Select((x1 >= y1), x1, y1)), [T.int32(-1), T.min_value("float32")]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")):
+                T.evaluate(T.tvm_thread_allreduce(T.uint32(1), A[i * 128 + threadIdx_x], True, reduce_temp0.data, threadIdx_x, dtype="handle"))
 
-def test_primfunc_with_multiple_reduce_group_commreducer():
-    func = comm_reducer_multiple_reduce_groups
-    rt_func = tvm.script.from_source(func.script(show_meta=True))
-    tvm.ir.assert_structural_equal(func, rt_func, True)
+    return comm_reducer_multiple_reduce_groups
 
 
-def test_primfunc_with_multiple_commreducer():
-    func = multiple_commreducer
-    rt_func = tvm.script.from_source(func.script(show_meta=True))
-    tvm.ir.assert_structural_equal(func, rt_func, True)
+def multiple_commreducer():
+    @T.prim_func
+    def multiple_commreducer() -> None:
+        normal_reduce_temp0 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local")
+        normal_reduce_temp1 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local")
+        reduce_temp0 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local")
+        reduce_temp1 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local")
+        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
+            with T.block("T_softmax_maxelem_cross_thread_reduction"):
+                T.attr(T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle"))
+                T.evaluate(T.tvm_thread_allreduce(T.uint32(1), normal_reduce_temp0[0], True, reduce_temp0.data, ax0_1, dtype="handle"))
+        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
+            with T.block("T_softmax_expsum_cross_thread_reduction"):
+                T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle"))
+                T.evaluate(T.tvm_thread_allreduce(T.uint32(1), normal_reduce_temp1[0], True, reduce_temp1.data, ax0_1, dtype="handle"))
+
+    return multiple_commreducer
+# fmt: on
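All of the hunks above follow one refactoring recipe: each module-level `@T.prim_func` plus its dedicated `test_*` function collapses into a zero-argument generator that constructs and returns a fresh PrimFunc. A condensed sketch of the recipe, using a hypothetical `trivial_case` (the real parameter list and `test_roundtrip` appear later in this diff):

import tvm
import tvm.testing
from tvm.script import tir as T

def trivial_case():
    @T.prim_func
    def trivial_case() -> None:
        T.evaluate(0)

    return trivial_case

# A single parametrized test then replaces the per-case test functions.
ir_generator = tvm.testing.parameter(trivial_case)

def test_roundtrip(ir_generator):
    original = ir_generator()
    after_roundtrip = tvm.script.from_source(original.script(show_meta=True))
    tvm.ir.assert_structural_equal(original, after_roundtrip, True)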
T.var("int32") - T.evaluate(a // b) - T.evaluate(a % b) - T.evaluate(a / b) - T.evaluate(T.truncmod(a, b)) + @T.prim_func + def func_div_mod(): + a = T.var("int32") + b = T.var("int32") + T.evaluate(a // b) + T.evaluate(a % b) + T.evaluate(a / b) + T.evaluate(T.truncmod(a, b)) + + return func_div_mod def test_div_mod(): - func = func_div_mod + func = func_div_mod() rt_func = tvm.script.from_source(func.script()) tvm.ir.assert_structural_equal(func, rt_func, True) @@ -3248,130 +3053,146 @@ def test_div_mod(): assert isinstance(func.body[3].value, tvm.tir.Mod) -@T.prim_func -def loop_extent_dependent(a: T.handle) -> None: - A = T.match_buffer(a, [], dtype="int32") - for i in T.serial(0, 128): - for j in T.serial(0, i): - A[()] = A[()] + j - - -def test_loop_extent_dependent(): - func = loop_extent_dependent - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func, True) - - -@T.prim_func -def nontrivial_range_axis(a: T.handle) -> None: - A = T.match_buffer(a, (10), "float32") - for i in range(10): - with T.block("block"): - vi = T.axis.spatial((1, 11), i + 1) - A[vi - 1] = A[vi - 1] + 1.0 - - -def test_nontrivial_range_axis(): - func = nontrivial_range_axis - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func, True) - +def loop_extent_dependent(): + @T.prim_func + def loop_extent_dependent(a: T.handle) -> None: + A = T.match_buffer(a, [], dtype="int32") + for i in T.serial(0, 128): + for j in T.serial(0, i): + A[()] = A[()] + j -@T.prim_func -def func_with_target_spec_by_config() -> None: - T.func_attr( - { - "kTarget": T.target( - { - "max_num_threads": 1024, - "arch": "sm_70", - "thread_warp_size": 32, - "kind": "cuda", - "tag": "", - "keys": ["cuda", "gpu"], - } - ) - } - ) - T.evaluate(0) + return loop_extent_dependent -@T.prim_func -def func_with_target_spec_by_str() -> None: - T.func_attr({"kTarget": T.target("nvidia/nvidia-a100")}) - T.evaluate(0) +def nontrivial_range_axis(): + @T.prim_func + def nontrivial_range_axis(a: T.handle) -> None: + A = T.match_buffer(a, (10), "float32") + for i in range(10): + with T.block("block"): + vi = T.axis.spatial((1, 11), i + 1) + A[vi - 1] = A[vi - 1] + 1.0 + return nontrivial_range_axis -def test_func_with_target_spec_by_config(): - func = func_with_target_spec_by_config - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func, True) +def func_with_target_spec_by_config(): + @T.prim_func + def func_with_target_spec_by_config() -> None: + T.func_attr( + { + "kTarget": T.target( + { + "max_num_threads": 1024, + "arch": "sm_70", + "thread_warp_size": 32, + "kind": "cuda", + "tag": "", + "keys": ["cuda", "gpu"], + } + ) + } + ) + T.evaluate(0) -def test_func_with_target_spec_by_str(): - func = func_with_target_spec_by_str - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func, True) + return func_with_target_spec_by_config -@T.prim_func -def func_root_attr(): - with T.block("root"): - T.block_attr({"a": "0"}) +def func_with_target_spec_by_str(): + @T.prim_func + def func_with_target_spec_by_str() -> None: + T.func_attr({"kTarget": T.target("nvidia/nvidia-a100")}) T.evaluate(0) + return func_with_target_spec_by_str -def test_root_attr(): - func = func_root_attr - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func, True) - - -@T.prim_func -def func_T_ptr_let_statement( - args: 
-@T.prim_func
-def func_T_ptr_let_statement(
-    args: T.handle, arg_type_ids_handle: T.Ptr[T.int32], num_args: T.int32
-) -> None:
-    # The T.Ptr declaration in the parameter list should parse
-    # correctly, and should be usable as the data pointer in a buffer.
-    arg_type_ids = T.buffer_decl([2], dtype="int32", data=arg_type_ids_handle)
-    arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle")
-    arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle")
+def func_root_attr():
+    @T.prim_func
+    def func_root_attr():
+        with T.block("root"):
+            T.block_attr({"a": "0"})
+            T.evaluate(0)
 
-    # Functions that return a "handle" can be assigned to a T.Ptr
-    # variable. A variable annotated with T.Ptr still has dtype of
-    # T.handle, but has type annotation as a pointer type.
-    A_data: T.Ptr[T.float32] = T.tvm_struct_get(arg0, 0, 1, dtype="handle")
+    return func_root_attr
 
-    # The buffer declaration has a data pointer defined earlier in
-    # this function. It should only be defined after the data pointer
-    # has been defined, and should not be hoisted into the header of
-    # the function as other buffer_decl statements can be.
-    A = T.buffer_decl([1024], dtype="float32", data=A_data)
-    B_data: T.Ptr[T.float32] = T.tvm_struct_get(arg1, 0, 1, dtype="handle")
-    B = T.buffer_decl([1024], dtype="float32", data=B_data)
-
-    B[0] = A[0]
+def func_T_ptr_let_statement():
+    @T.prim_func
+    def func_T_ptr_let_statement(
+        args: T.handle, arg_type_ids_handle: T.Ptr[T.int32], num_args: T.int32
+    ) -> None:
+        # The T.Ptr declaration in the parameter list should parse
+        # correctly, and should be usable as the data pointer in a buffer.
+        arg_type_ids = T.buffer_decl([2], dtype="int32", data=arg_type_ids_handle)
+        arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle")
+        arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle")
 
-def test_T_ptr_let_statement():
-    func = func_T_ptr_let_statement
-    rt_func = tvm.script.from_source(func.script(show_meta=True))
-    tvm.ir.assert_structural_equal(func, rt_func, True)
+        # Functions that return a "handle" can be assigned to a T.Ptr
+        # variable. A variable annotated with T.Ptr still has dtype of
+        # T.handle, but has type annotation as a pointer type.
+        A_data: T.Ptr[T.float32] = T.tvm_struct_get(arg0, 0, 1, dtype="handle")
 
+        # The buffer declaration has a data pointer defined earlier in
+        # this function. It should only be defined after the data pointer
+        # has been defined, and should not be hoisted into the header of
+        # the function as other buffer_decl statements can be.
+        A = T.buffer_decl([1024], dtype="float32", data=A_data)
+        B_data: T.Ptr[T.float32] = T.tvm_struct_get(arg1, 0, 1, dtype="handle")
+        B = T.buffer_decl([1024], dtype="float32", data=B_data)
-@T.prim_func
-def func_T_ptr_allocate() -> None:
-    A_data: T.Ptr[T.float32] = T.allocate([1024], "float32", "global")
-    A = T.buffer_decl([1024], dtype="float32", data=A_data)
+
+        B[0] = A[0]
 
-    A[0] = 0.0
+    return func_T_ptr_let_statement
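The comments in `func_T_ptr_let_statement` describe the key invariant: a `T.Ptr` parameter is still a `tir.Var` whose dtype is "handle", but it carries a pointer type annotation. A small check of that invariant, assuming a TVM build of the same era as this patch (the function name is illustrative):

import tvm
from tvm.script import tir as T

@T.prim_func
def ptr_param(data: T.Ptr[T.float32]) -> None:
    T.evaluate(0)

param = ptr_param.params[0]
assert param.dtype == "handle"
assert isinstance(param.type_annotation, tvm.ir.PointerType)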
 
 
-def test_T_ptr_allocate():
-    func = func_T_ptr_allocate
-    rt_func = tvm.script.from_source(func.script(show_meta=True))
-    tvm.ir.assert_structural_equal(func, rt_func, True)
+def func_T_ptr_allocate():
+    @T.prim_func
+    def func_T_ptr_allocate() -> None:
+        A = T.allocate([1024], "float32", "global")
+        A[0] = 0.0
+
+    return func_T_ptr_allocate
+
+
+ir_generator = tvm.testing.parameter(
+    opt_gemm_normalize,
+    opt_gemm_lower,
+    opt_gemm_mod_host,
+    opt_conv_tensorcore_normalize,
+    opt_conv_tensorcore_lower,
+    opt_conv_tensorcore_mod_host,
+    vthread_func,
+    matmul,
+    module_const,
+    constant,
+    rank0,
+    rank0_block,
+    select,
+    minmax,
+    abs,
+    constant_folding,
+    simplify_bracket,
+    while_loop,
+    primfunc_with_allocate_annotations,
+    comm_reducer_single_reduce_group,
+    comm_reducer_multiple_reduce_groups,
+    multiple_commreducer,
+    loop_extent_dependent,
+    nontrivial_range_axis,
+    func_with_target_spec_by_config,
+    func_with_target_spec_by_str,
+    func_root_attr,
+    func_T_ptr_let_statement,
+    func_T_ptr_allocate,
+)
+
+
+def test_roundtrip(ir_generator):
+    original = ir_generator()
+    after_roundtrip = tvm.script.from_source(original.script(show_meta=True))
+    tvm.ir.assert_structural_equal(original, after_roundtrip, True)
 
 
 @T.prim_func
diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py
index 383841f19e34..1e8247c6e135 100644
--- a/vta/python/vta/transform.py
+++ b/vta/python/vta/transform.py
@@ -156,15 +156,45 @@ def CPUAccessRewrite():
     """
 
     def _ftransform(f, mod, ctx):
-        rw_info = {}
         env = get_env()
+        var_remap = {}
+        buf_remap = {}
+
+        def find_var_remap(old_var):
+            if old_var in var_remap:
+                return var_remap[old_var]
+
+            new_var = tvm.tir.Var(old_var.name + "_ptr", dtype=old_var.type_annotation)
+            var_remap[old_var] = new_var
+            return new_var
+
+        def find_buf_remap(old_buf):
+            if old_buf in buf_remap:
+                return buf_remap[old_buf]
+
+            new_var = find_var_remap(old_buf.data)
+            new_buf = tvm.tir.decl_buffer(
+                shape=old_buf.shape,
+                dtype=old_buf.dtype,
+                data=new_var,
+                strides=old_buf.strides,
+                elem_offset=old_buf.elem_offset,
+                scope=old_buf.scope,
+                data_alignment=old_buf.data_alignment,
+                offset_factor=old_buf.offset_factor,
+                buffer_type="auto_broadcast" if (old_buf.buffer_type == 2) else "",
+                axis_separators=old_buf.axis_separators,
+            )
+            buf_remap[old_buf] = new_buf
+            return new_buf
+
         def _post_order(op):
             if isinstance(op, tvm.tir.Allocate):
                 buffer_var = op.buffer_var
-                if not buffer_var in rw_info:
+                if buffer_var not in var_remap:
                     return None
-                new_var = rw_info[buffer_var]
+                new_var = var_remap[buffer_var]
                 let_stmt = tvm.tir.LetStmt(
                     new_var,
                     tvm.tir.call_extern(
@@ -173,33 +203,31 @@ def _post_order(op):
                     op.body,
                 )
                 alloc = tvm.tir.Allocate(buffer_var, op.dtype, op.extents, op.condition, let_stmt)
-                del rw_info[buffer_var]
+                del var_remap[buffer_var]
+                bufs_to_delete = [
+                    old_buf for old_buf in buf_remap if old_buf.data.same_as(buffer_var)
+                ]
+                for buf in bufs_to_delete:
+                    del buf_remap[buf]
                 return alloc
-            if isinstance(op, tvm.tir.Load):
-                buffer_var = op.buffer_var
-                if not buffer_var in rw_info:
-                    rw_info[buffer_var] = te.var(buffer_var.name + "_ptr", "handle")
-                new_var = rw_info[buffer_var]
-                return tvm.tir.Load(op.dtype, new_var, op.index)
-            if isinstance(op, tvm.tir.Store):
-                buffer_var = op.buffer_var
-                if not buffer_var in rw_info:
-                    rw_info[buffer_var] = te.var(buffer_var.name + "_ptr", "handle")
-                new_var = rw_info[buffer_var]
-                return tvm.tir.Store(new_var, op.value, op.index)
+
+            if isinstance(op, tvm.tir.BufferLoad):
+                return tvm.tir.BufferLoad(find_buf_remap(op.buffer), op.indices)
+
+            if isinstance(op, tvm.tir.BufferStore):
+                return tvm.tir.BufferStore(find_buf_remap(op.buffer), op.value, op.indices)
+
             raise RuntimeError("not reached")
 
         stmt_in = f.body
         stmt = tvm.tir.stmt_functor.ir_transform(
-            stmt_in, None, _post_order, ["tir.Allocate", "tir.Load", "tir.Store"]
+            stmt_in, None, _post_order, ["tir.Allocate", "tir.BufferLoad", "tir.BufferStore"]
         )
-        for buffer_var, new_var in rw_info.items():
+        for old_var, new_var in var_remap.items():
             stmt = tvm.tir.LetStmt(
                 new_var,
-                tvm.tir.call_extern(
-                    "handle", "VTABufferCPUPtr", env.dev.command_handle, buffer_var
-                ),
+                tvm.tir.call_extern("handle", "VTABufferCPUPtr", env.dev.command_handle, old_var),
                 stmt,
             )
         return f.with_body(stmt)
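The hunk above is the heart of the VTA migration: the `tir.Load`/`tir.Store` branches become `tir.BufferLoad`/`tir.BufferStore` branches, and the pass remaps whole `Buffer` objects instead of raw pointer `Var`s. Stripped of the VTA specifics, the rewrite skeleton looks like the following sketch, where `remap_buffer` stands in for the pass's `find_buf_remap`:

import tvm

def remap_buffer(buf):
    # Placeholder: CPUAccessRewrite rebuilds the buffer around a new data Var.
    return buf

def _post_order(op):
    if isinstance(op, tvm.tir.BufferLoad):
        return tvm.tir.BufferLoad(remap_buffer(op.buffer), op.indices)
    if isinstance(op, tvm.tir.BufferStore):
        return tvm.tir.BufferStore(remap_buffer(op.buffer), op.value, op.indices)
    return None  # leave all other nodes untouched

def rewrite(stmt):
    return tvm.tir.stmt_functor.ir_transform(
        stmt, None, _post_order, ["tir.BufferLoad", "tir.BufferStore"]
    )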
@@ -919,8 +947,8 @@ def _flatten_loop(src_coeff, dst_coeff, extents):
             loop_body = loop_body.body
             nest_size += 1
         # Get the src/dst arguments
-        dst_var = loop_body.buffer_var
-        dst_idx = loop_body.index
+        dst_var = loop_body.buffer.data
+        dst_idx = loop_body.indices[0]
         # Derive loop variables and extents
         tmp_body = stmt.body
         indices = []
@@ -963,7 +991,7 @@ def _flatten_loop(src_coeff, dst_coeff, extents):
                     raise RuntimeError(
                         "Function call not recognized %s" % (loop_body.value.name)
                     )
-            elif isinstance(loop_body.value, tvm.tir.Load):
+            elif isinstance(loop_body.value, tvm.tir.BufferLoad):
                 alu_opcode = env.dev.ALU_OPCODE_SHR
                 lhs = loop_body.value
                 rhs = tvm.tir.const(0, "int32")
@@ -979,20 +1007,20 @@ def _flatten_loop(src_coeff, dst_coeff, extents):
             use_imm = False
             imm_val = None
             if isinstance(rhs, tvm.tir.IntImm):
-                assert lhs.buffer_var.same_as(dst_var)
-                src_coeff = tvm.arith.detect_linear_equation(lhs.index, indices)
+                assert lhs.buffer.data.same_as(dst_var)
+                src_coeff = tvm.arith.detect_linear_equation(lhs.indices[0], indices)
                 use_imm = True
                 imm_val = rhs
             if isinstance(lhs, tvm.tir.IntImm):
-                assert rhs.buffer_var.same_as(dst_var)
-                src_coeff = tvm.arith.detect_linear_equation(rhs.index, indices)
+                assert rhs.buffer.data.same_as(dst_var)
+                src_coeff = tvm.arith.detect_linear_equation(rhs.indices[0], indices)
                 use_imm = True
                 imm_val = lhs
             if imm_val is None:
                 imm_val = 0
-                assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var)
-                src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.index, indices)
-                src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.index, indices)
+                assert lhs.buffer.data.same_as(dst_var) and rhs.buffer.data.same_as(dst_var)
+                src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.indices[0], indices)
+                src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.indices[0], indices)
             # Determine which side has the same coefficients
             lhs_equal = True
             rhs_equal = True
@@ -1058,7 +1086,12 @@ def _flatten_loop(src_coeff, dst_coeff, extents):
         for idx, extent in enumerate(extents):
             irb.emit(
                 tvm.tir.call_extern(
-                    "int32", "VTAUopLoopBegin", extent, dst_coeff[idx], src_coeff[idx], 0
+                    "int32",
+                    "VTAUopLoopBegin",
+                    extent,
+                    dst_coeff[idx],
+                    src_coeff[idx],
+                    0,
                 )
             )
         use_imm = int(use_imm)
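The remaining hunks apply the same mechanical translation inside the ALU intrinsic injection: `node.buffer_var` becomes `node.buffer.data`, and the scalar `node.index` becomes `node.indices[0]`, which is safe here because VTA operates on flat one-dimensional buffers. A small sketch of the correspondence:

import tvm

# For a flat (1-D) buffer, the old Load.index corresponds to
# BufferLoad.indices[0], and the old Load.buffer_var to
# BufferLoad.buffer.data.
buf = tvm.tir.decl_buffer([16], "int32", name="A")
load = tvm.tir.BufferLoad(buf, [tvm.tir.const(0, "int32")])
assert load.buffer.data.same_as(buf.data)
assert len(load.indices) == 1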