diff --git a/CMakeLists.txt b/CMakeLists.txt index e16c09f13e9..5b9e2df86b5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -177,6 +177,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/type_inference.cpp ${NVFUSER_SRCS_DIR}/type_promotion.cpp ${NVFUSER_SRCS_DIR}/fusion_segmenter.cpp + ${NVFUSER_SRCS_DIR}/tensor_metadata.cpp ${NVFUSER_SRCS_DIR}/tensor_view.cpp ${NVFUSER_SRCS_DIR}/transform_iter.cpp ${NVFUSER_SRCS_DIR}/transform_replay.cpp diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index aea4a78705d..3580ad6eb09 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -575,7 +575,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { [&](auto&& dtype) { using T = std::decay_t; if constexpr (std::is_same_v) { - for (auto& [name, _] : dtype.types) { + for (auto& name : dtype.field_names) { indent() << gen(gop->output(0)) << "." << name << " = " << gen(gop->in()) << "." << name << ";\n"; } @@ -1376,7 +1376,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { ldst->out()->dtype(), " = ", ldst->in()->dtype()); - for (auto& [name, _] : out_type.types) { + for (auto& name : out_type.field_names) { TORCH_INTERNAL_ASSERT( in_type.types.find(name) != in_type.types.end(), "Mismatched field in struct assignment: ", diff --git a/csrc/device_lower/pass/replace_size.cpp b/csrc/device_lower/pass/replace_size.cpp index 522a3cb541e..36581612509 100644 --- a/csrc/device_lower/pass/replace_size.cpp +++ b/csrc/device_lower/pass/replace_size.cpp @@ -203,7 +203,7 @@ void replaceSymbolicSizes(Fusion* fusion) { if (tensor_dim_map.find(orig_size) == tensor_dim_map.end() && !orig_size->isFusionInput() && !orig_size->isConstScalar()) { std::stringstream ss; - ss << "T" << tv->name() << ".size[" << dim++ << "]"; + ss << "T" << tv->name() << ".logical_size[" << dim++ << "]"; tensor_dim_map[orig_size] = IrBuilder::create( ss.str(), orig_size->getDataType().value()); } else { diff --git a/csrc/executor.cpp b/csrc/executor.cpp index fa5febaa8fb..79f1c10946a 100644 --- 
a/csrc/executor.cpp +++ b/csrc/executor.cpp @@ -1838,8 +1838,8 @@ float FusionExecutor::runRtc( Struct concrete_value; concrete_value["data"] = PolymorphicValue( Pointer(input.data_ptr(), aten_to_data_type(input.scalar_type()))); - concrete_value["size"] = PolymorphicValue(input.sizes().vec()); - concrete_value["stride"] = PolymorphicValue(input.strides().vec()); + concrete_value["logical_size"] = PolymorphicValue(input.sizes().vec()); + concrete_value["alloc_stride"] = PolymorphicValue(input.strides().vec()); data.emplace_back(getTensorArgBuffer(concrete_value, index_type)); pointers.emplace_back(data.back().data()); } diff --git a/csrc/executor_kernel_arg.cpp b/csrc/executor_kernel_arg.cpp index 46a66371ec4..e920898e93c 100644 --- a/csrc/executor_kernel_arg.cpp +++ b/csrc/executor_kernel_arg.cpp @@ -14,303 +14,6 @@ namespace nvfuser { -namespace { - -// Forward traverse from rFactor domain to allocation domain, compute frontier -// sizes and strides, validate that splits are divisible and merges are -// contiguous, and update active_ids_ correspondingly. 
-class ForwardTraverseFromRFactorToAlloc { - ExpressionEvaluator& ee_; - std::unordered_map>& active_ids_; - - void handle(Split* split) { - auto in = split->in(); - auto inner = split->inner(); - auto outer = split->outer(); - auto in_it = active_ids_.find(in); - // TORCH_INTERNAL_ASSERT(in_it != active_ids_.end()) - if (in_it == active_ids_.end()) { - // TODO: see [Allocation domain on both side of rFactor] - return; - } - auto [in_size, in_stride] = in_it->second; - auto factor = ee_.evaluate(split->factor()).as(); - TORCH_INTERNAL_ASSERT( - in_size % factor == 0, - "The rFactor domain and allocation domain of fusion input/output ", - "tensors must be a one-to-one map, therefore, ", - "non-divisible split is not allowed in allocation domain"); - TORCH_INTERNAL_ASSERT(active_ids_.erase(in) == 1); - TORCH_INTERNAL_ASSERT( - active_ids_ - .emplace(inner, std::pair{factor, in_stride}) - .second); - TORCH_INTERNAL_ASSERT(active_ids_ - .emplace( - outer, - std::pair{ - in_size / factor, in_stride * factor}) - .second); - } - - void handle(Merge* merge) { - auto inner = merge->inner(); - auto outer = merge->outer(); - auto out = merge->out(); - auto inner_it = active_ids_.find(inner); - auto outer_it = active_ids_.find(outer); - // TORCH_INTERNAL_ASSERT(inner_it != active_ids_.end()) - // TORCH_INTERNAL_ASSERT(outer_it != active_ids_.end()) - if (inner_it == active_ids_.end() || outer_it == active_ids_.end()) { - // TODO: see [Allocation domain on both side of rFactor] - return; - } - auto [inner_size, inner_stride] = inner_it->second; - auto [outer_size, outer_stride] = outer_it->second; - TORCH_INTERNAL_ASSERT( - inner_stride * inner_size == outer_stride, - "The rFactor domain and allocation domain of fusion input/output ", - "tensors must be a one-to-one map, therefore, ", - "merging of discontiguous dimensions is not allowed in allocation domain"); - TORCH_INTERNAL_ASSERT(active_ids_.erase(inner) == 1); - TORCH_INTERNAL_ASSERT(active_ids_.erase(outer) == 1); - 
TORCH_INTERNAL_ASSERT(active_ids_ - .emplace( - out, - std::pair{ - inner_size * outer_size, inner_stride}) - .second); - } - - void handle(Expr* expr) { - if (auto split = dynamic_cast(expr)) { - handle(split); - } else if (auto merge = dynamic_cast(expr)) { - handle(merge); - } else { - TORCH_INTERNAL_ASSERT( - false, "Unsupported transormation in allocation domain"); - } - } - - public: - ForwardTraverseFromRFactorToAlloc( - ExpressionEvaluator& ee, - std::unordered_map>& active_ids) - : ee_(ee), active_ids_(active_ids) {} - - void run( - TensorView* tv, - const std::vector& rfactor, - const std::vector& alloc) { - auto forward_exprs = StmtSort::getExprsBetween( - tv->fusion(), - {rfactor.begin(), rfactor.end()}, - {alloc.begin(), alloc.end()}); - for (auto expr : forward_exprs) { - handle(expr); - } - } -}; - -// Similar to ForwardTraverseFromRFactorToAlloc, but in the opposite direction. -class BackwardTraverseFromRFactorToAlloc { - at::Tensor tensor_; - ExpressionEvaluator& ee_; - std::unordered_map>& active_ids_; - - void handle(Split* split) { - auto in = split->in(); - auto inner = split->inner(); - auto outer = split->outer(); - auto inner_it = active_ids_.find(inner); - auto outer_it = active_ids_.find(outer); - // TORCH_INTERNAL_ASSERT(inner_it != active_ids_.end()) - // TORCH_INTERNAL_ASSERT(outer_it != active_ids_.end()) - if (inner_it == active_ids_.end() || outer_it == active_ids_.end()) { - // TODO: see [Allocation domain on both side of rFactor] - return; - } - auto [inner_size, inner_stride] = inner_it->second; - auto [outer_size, outer_stride] = outer_it->second; - TORCH_INTERNAL_ASSERT( - inner_stride * inner_size == outer_stride, - "The rFactor domain and allocation domain of fusion input/output ", - "tensors must be a one-to-one map, therefore, ", - "splitting one dimension into discontiguous dimensions is not allowed in allocation domain"); - TORCH_INTERNAL_ASSERT(active_ids_.erase(inner) == 1); - 
TORCH_INTERNAL_ASSERT(active_ids_.erase(outer) == 1); - TORCH_INTERNAL_ASSERT(active_ids_ - .emplace( - in, - std::pair{ - inner_size * outer_size, inner_stride}) - .second); - } - - void handle(Merge* merge) { - auto inner = merge->inner(); - auto outer = merge->outer(); - auto out = merge->out(); - auto factor = ee_.evaluate(inner->extent()).as(); - auto out_it = active_ids_.find(out); - // TORCH_INTERNAL_ASSERT(out_it != active_ids_.end()) - if (out_it == active_ids_.end()) { - // TODO: see [Allocation domain on both side of rFactor] - return; - } - auto [out_size, out_stride] = out_it->second; - TORCH_INTERNAL_ASSERT( - out_size % factor == 0, - "The rFactor domain and allocation domain of fusion input/output ", - "tensors must be a one-to-one map, therefore, ", - "the size of the output must divisible by the size of inner dimension"); - TORCH_INTERNAL_ASSERT(active_ids_.erase(out) == 1); - TORCH_INTERNAL_ASSERT( - active_ids_ - .emplace(inner, std::pair{factor, out_stride}) - .second); - TORCH_INTERNAL_ASSERT(active_ids_ - .emplace( - outer, - std::pair{ - out_size / factor, out_stride * factor}) - .second); - } - - void handle(Expr* expr) { - if (auto split = dynamic_cast(expr)) { - handle(split); - } else if (auto merge = dynamic_cast(expr)) { - handle(merge); - } else { - TORCH_INTERNAL_ASSERT( - false, "Unsupported transormation in allocation domain"); - } - } - - public: - BackwardTraverseFromRFactorToAlloc( - ExpressionEvaluator& ee, - std::unordered_map>& active_ids) - : ee_(ee), active_ids_(active_ids) {} - - void run( - TensorView* tv, - const std::vector& rfactor, - const std::vector& alloc) { - auto backward_exprs = StmtSort::getExprsBetween( - tv->fusion(), - {alloc.begin(), alloc.end()}, - {rfactor.begin(), rfactor.end()}); - std::reverse(backward_exprs.begin(), backward_exprs.end()); - for (auto expr : backward_exprs) { - handle(expr); - } - } -}; - -} // namespace - -// Given an ATen tensor, whose sizes and strides are w.r.t to the rFactor 
domain -// of its corresponding TensorView, compute the sizes and strides of the tensor -// with respect to its allocation domain. -// For example, if the rFactor domain is [I1, I2], and the allocation domain is -// [I2*I1], and the tensor's size is [5, 3] and stride is [2, 10], then the -// resulting size will be [15] and stride will be [2] -// Another example, if the rFactor domain is [I1*I2] and the allocation domain -// is [I1, I2], and the tensor's size is [15] and stride is [7], and the extent -// of I2 is 5, then the resulting size will be [3, 5] and stride will be [35, 7] -std::vector> -inferAndValidateAllocationSizesAndStrides( - const at::Tensor& tensor, - TensorView* tv, - ExpressionEvaluator& ee) { - if (tv == nullptr || !tv->hasAllocation()) { - // When tv is nullptr, or tv does not have allocation, the given sizes and - // strides should already be in the target format. So nothing to do here. - std::vector> result; - for (auto i : c10::irange(tensor.dim())) { - result.emplace_back(tensor.size(i), tensor.stride(i)); - } - return result; - } - const auto& alloc = - TensorDomain::noReductions(tv->getMaybeAllocationDomain()); - const auto& rfactor = TensorDomain::noReductions(tv->getMaybeRFactorDomain()); - - // active IDs and their shape and stride - std::unordered_map> active_ids; - TORCH_INTERNAL_ASSERT((int64_t)rfactor.size() == tensor.dim()); - for (int64_t i : c10::irange((int64_t)rfactor.size())) { - auto rf_id = rfactor.at(i); - active_ids[rf_id] = {tensor.size(i), tensor.stride(i)}; - } - - ForwardTraverseFromRFactorToAlloc(ee, active_ids).run(tv, rfactor, alloc); - BackwardTraverseFromRFactorToAlloc(ee, active_ids).run(tv, rfactor, alloc); - - // Now active_ids should contain the final sizes and strides, unordered. We - // need to put them to the correct order. 
- std::vector> sizes_strides; - sizes_strides.reserve(alloc.size()); - for (auto i : c10::irange(alloc.size())) { - auto id = alloc.at(i); - sizes_strides.emplace_back(active_ids.at(id)); - } - // Validate final sizes and strides with contiguity - int64_t contiguous_stride = 1; - std::vector> contiguity = tv->getContiguity(); - for (int64_t i = (int64_t)sizes_strides.size() - 1; i >= 0; i--) { - if (alloc.at(i)->isBroadcast()) { - continue; - } - while (!contiguity.back().has_value()) { - contiguity.pop_back(); - } - auto [size, stride] = sizes_strides.at(i); - TORCH_INTERNAL_ASSERT(!contiguity.empty()); - auto last_contiguity = contiguity.back(); - TORCH_INTERNAL_ASSERT( - last_contiguity.has_value(), - "I don't think this check makes sense, but unfortunately ", - "clang-tidy is not smart enough to infer from the context that this is always true."); - if (*last_contiguity) { - TORCH_CHECK( - stride == contiguous_stride, - "Stride mismatch with contiguity info. ", - "tv: ", - tv->toString(), - " allocation domain: ", - ir_utils::toString(tv->getMaybeAllocationDomain()), - " dim: ", - i, - " expected stride: ", - contiguous_stride, - " actual stride: ", - stride); - } - contiguous_stride = stride * size; - contiguity.pop_back(); - } - TORCH_INTERNAL_ASSERT( - contiguity.empty(), - "The size of contiguity mismatch with the dimensionality of allocation domain"); - // Validate that for expanded broadcast, the stride must be zero. 
- for (int64_t i : c10::irange((int64_t)sizes_strides.size())) { - if (auto alloc_id = alloc.at(i); alloc_id->hasExpandedExtent()) { - auto [_, stride] = sizes_strides.at(i); - TORCH_CHECK( - stride == 0, - "Expecting an expanded dimension on dimension ", - i, - " but found stride ", - stride); - } - } - return sizes_strides; -} - PrimDataType TensorArgAbstract::getSmallestIndexType() const { KernelIndexTypeCompute index_type_helper; for (const auto dim_i : c10::irange(tensor_.ndimension())) { @@ -637,8 +340,8 @@ std::vector getTensorArgBuffer( auto struct_ = metadata.as(); std::vector buffer; void* ptr = (void*)struct_["data"]; - std::vector sizes = (std::vector)struct_["size"]; - std::vector strides = (std::vector)struct_["stride"]; + std::vector sizes = (std::vector)struct_["logical_size"]; + std::vector strides = (std::vector)struct_["alloc_stride"]; if (index_type == PrimDataType::Int) { buffer.reserve( sizeof(ptr) + sizeof(int64_t) * (sizes.size() + strides.size())); @@ -686,10 +389,8 @@ std::vector getKernelArgument( (std::byte*)tensor.data_ptr(), (std::byte*)tensor.data_ptr() + tensor.element_size()); } else { - auto resolved_arg = getTensorArg(tensor, tv, ee, index_type); - return std::vector( - (std::byte*)resolved_arg->arg(), - (std::byte*)resolved_arg->arg() + resolved_arg->argSize()); + auto metadata = ee.evaluate(IrBuilder::metadataExpr(tv)); + return getTensorArgBuffer(metadata, index_type); } } else if (isIntegralType(parameter->dtype())) { int64_t v = pv.as(); diff --git a/csrc/executor_kernel_arg.h b/csrc/executor_kernel_arg.h index 551f1247a2b..c04d75a04db 100644 --- a/csrc/executor_kernel_arg.h +++ b/csrc/executor_kernel_arg.h @@ -278,13 +278,6 @@ struct TensorArgAbstract : ArgAbstract { } }; -// TODO: move this to GetMetaData::evaluate -std::vector> -inferAndValidateAllocationSizesAndStrides( - const at::Tensor& tensor, - TensorView* tv, - ExpressionEvaluator& ee); - // TODO: remove this template struct TensorArg : public TensorArgAbstract { @@ 
-297,22 +290,6 @@ struct TensorArg : public TensorArgAbstract { for (const auto i : c10::irange(tensor.ndimension())) { instance_.setSize(i, (typename TENSOR_TYPE::index_type)tensor.size(i)); } - inferSetAndValidateStrides(tensor, tv, eval); - } - - void inferSetAndValidateStrides( - const at::Tensor& tensor, - TensorView* tv, - ExpressionEvaluator& eval) { - auto sizes_strides = - inferAndValidateAllocationSizesAndStrides(tensor, tv, eval); - TORCH_INTERNAL_ASSERT( - (size_t)instance_.nAllocationDims() == sizes_strides.size()); - for (auto i : c10::irange((int64_t)sizes_strides.size())) { - alloc_sizes.at(i) = sizes_strides.at(i).first; - using stride_t = typename TENSOR_TYPE::index_type; - instance_.setStride(i, (stride_t)sizes_strides.at(i).second); - } } int64_t getAllocRank() const override { diff --git a/csrc/executor_utils.cpp b/csrc/executor_utils.cpp index 110519b0401..6dcd68f4733 100644 --- a/csrc/executor_utils.cpp +++ b/csrc/executor_utils.cpp @@ -610,10 +610,10 @@ void validateAlignedVectorizeExtents( void validateAlignedVectorizedFusionInputOutput( const at::Tensor& aten_tensor, int word_size, - TensorView* tv) { - ExpressionEvaluator eval; - auto sizes_strides = - inferAndValidateAllocationSizesAndStrides(aten_tensor, tv, eval); + TensorView* tv, + ExpressionEvaluator eval) { + eval.bind(tv, aten_tensor); + auto metadata = eval.evaluate(IrBuilder::metadataExpr(tv)); std::vector no_reduction_to_full; for (int64_t i : @@ -623,7 +623,11 @@ void validateAlignedVectorizedFusionInputOutput( no_reduction_to_full.emplace_back(i); } } - TORCH_INTERNAL_ASSERT(sizes_strides.size() == no_reduction_to_full.size()); + + auto sizes = std::vector(metadata["alloc_size"]); + auto strides = std::vector(metadata["alloc_stride"]); + TORCH_INTERNAL_ASSERT(sizes.size() == no_reduction_to_full.size()); + TORCH_INTERNAL_ASSERT(strides.size() == no_reduction_to_full.size()); TORCH_INTERNAL_ASSERT( reinterpret_cast(aten_tensor.data_ptr()) % @@ -643,8 +647,9 @@ void 
validateAlignedVectorizedFusionInputOutput( // domain must have stride 1. int64_t cur_contig_stride = 1; bool still_rightmost = true; - for (int64_t i = (int64_t)sizes_strides.size() - 1; i >= 0; --i) { - const auto [size, stride] = sizes_strides.at(i); + for (int64_t i = (int64_t)sizes.size() - 1; i >= 0; --i) { + const auto size = sizes.at(i); + const auto stride = strides.at(i); auto alloc_id = tv->getMaybeAllocationDomain().at(no_reduction_to_full.at(i)); const auto is_expanded_broadcasting = @@ -673,7 +678,9 @@ void validateAlignedVectorizedFusionInputOutput( " Domain: ", tv->axis(i)->toString(), ", stride: ", - stride) + stride, + ", cur_contig_stride ", + cur_contig_stride); // If the domain is size-1, the next domain is still considered // rightmost. still_rightmost = @@ -717,14 +724,15 @@ void validateAlignedVectorizedTensors( dynamic_cast(args[pos]); TORCH_INTERNAL_ASSERT(tensor_arg_abstract, "alias io only supports tensor"); validateAlignedVectorizedFusionInputOutput( - tensor_arg_abstract->getTensor(), word_size, tv); + tensor_arg_abstract->getTensor(), word_size, tv, expr_eval); } if (!outputs.empty()) { for (auto pos : tensor_vectorization_validation_entry.get() .aligned_vectorized_out_tensor_pos) { auto tv = kernel->outputs().at(pos)->as(); auto word_size = kernel->summary().vectorized_accesses.at(tv); - validateAlignedVectorizedFusionInputOutput(outputs[pos], word_size, tv); + validateAlignedVectorizedFusionInputOutput( + outputs[pos], word_size, tv, expr_eval); } } } diff --git a/csrc/index_compute.cpp b/csrc/index_compute.cpp index 01b1c4159e2..ff0c86153f5 100644 --- a/csrc/index_compute.cpp +++ b/csrc/index_compute.cpp @@ -1402,7 +1402,7 @@ std::vector Index::getGlobalProducerStridedIndices( } strides[i] = IrBuilder::getItemExpr( IrBuilder::getAttrExpr( - IrBuilder::metadataExpr(producer_tv), "stride"), + IrBuilder::metadataExpr(producer_tv), "alloc_stride"), stride_i++); } } @@ -1758,7 +1758,7 @@ std::vector Index::getStrides(TensorView* tv) { 
continue; } strides[i] = IrBuilder::getItemExpr( - IrBuilder::getAttrExpr(IrBuilder::metadataExpr(tv), "stride"), + IrBuilder::getAttrExpr(IrBuilder::metadataExpr(tv), "alloc_stride"), stride_i++); } } diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index c1635d15eda..e81c49e5bf2 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -767,33 +767,6 @@ std::string GetMetaData::toInlineString(int indent_size) const { return ss.str(); } -std::vector GetMetaData::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - TORCH_INTERNAL_ASSERT(inputs.size() == 1, "GetMetaData expects 1 input"); - TORCH_INTERNAL_ASSERT( - in()->isA(), - "Currently, GetMetaData only supports TensorView"); - TensorView* tv = in()->as(); - if (tv->getMemoryType() == MemoryType::Shared) { - // Smem tensor is defined locally as a pointer. It is impossible to know the - // actual address, but using nullptr is a good approximation. - return {PolymorphicValue(Pointer(nullptr, tv->dtype()))}; - } - - at::Tensor input = inputs.at(0).as(); - - Struct concrete_value; - concrete_value["data"] = - PolymorphicValue(Pointer(input.data_ptr(), tv->dtype())); - concrete_value["size"] = PolymorphicValue(input.sizes().vec()); - // TODO: this is not correct, strides actually needs to be based on allocation - // domain, but input.strides() is on the rFactor domain. We need to refactor - // our executor to move related logic here. 
- concrete_value["stride"] = PolymorphicValue(input.strides().vec()); - return {PolymorphicValue(concrete_value)}; -} - NVFUSER_DEFINE_CLONE_AND_CREATE(GetMetaData) TensorConstruct::TensorConstruct( @@ -3648,7 +3621,7 @@ bool NamedScalar::sameAs(const Statement* other) const { } bool NamedScalar::isTensorSize() const { - static const std::regex r(R"(T\d+\.size\[\d+\])"); + static const std::regex r(R"(T\d+\.\w*size\[\d+\])"); return std::regex_match(name(), r); } diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 46590eb3864..9ee0257f3da 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -1149,11 +1149,12 @@ bool isTensorSize(const Val* val) { return true; } } - return isTensorAttr(val, "size"); + return isTensorAttr(val, "logical_size") || isTensorAttr(val, "alloc_size"); } bool isTensorStride(const Val* val) { - return isTensorAttr(val, "stride"); + return isTensorAttr(val, "logical_stride") || + isTensorAttr(val, "alloc_stride"); } } // namespace nvfuser::ir_utils diff --git a/csrc/polymorphic_value.h b/csrc/polymorphic_value.h index 01eaf7fa2be..a8747ca20c3 100644 --- a/csrc/polymorphic_value.h +++ b/csrc/polymorphic_value.h @@ -168,6 +168,10 @@ class Pointer { explicit operator unsigned() const { return (unsigned)(int64_t)(*this); } + + explicit operator size_t() const { + return reinterpret_cast(ptr_); + } }; inline Pointer operator+(int64_t offset, const Pointer& ptr) { diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index 92dc047d50d..e56ce83c4bf 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -971,34 +971,32 @@ SchedulerRuntimeInfo::SchedulerRuntimeInfo( *expression_evaluator_); } - // Convert all abstract tensor args into tensor args and do tensor stride - // inference - std::vector tvs; - tvs.reserve(complete_fusion_->inputs().size()); - for (auto val : complete_fusion_->inputs()) { - tvs.emplace_back(dynamic_cast(val)); - } - args.getBuffer(index_type_, tvs, *expression_evaluator_); - for 
(auto inp_i : c10::irange(static_cast(args.size()))) { - auto kernel_arg = args[inp_i]; + auto fusion_inp = complete_fusion_->inputs().at(inp_i); + auto input_tv = dynamic_cast(fusion_inp); // Note: we are skipping CpuScalar tensor here - if (auto tensor_arg_abstract = - dynamic_cast(kernel_arg)) { - auto fusion_inp = complete_fusion_->inputs()[inp_i]; - input_ptrs_[fusion_inp] = tensor_arg_abstract->getPointerAddress(); + if (input_tv != nullptr && !input_tv->isCpuScalar()) { + auto metadata = + expression_evaluator_->evaluate(IrBuilder::metadataExpr(input_tv)); + std::vector alloc_sizes = + (std::vector)metadata["alloc_size"]; + std::vector alloc_strides = + (std::vector)metadata["alloc_stride"]; + TORCH_INTERNAL_ASSERT(alloc_sizes.size() == alloc_strides.size()); + + input_ptrs_[fusion_inp] = (size_t)metadata["data"]; // find and push discontiguous stride - auto dtype_size = dataTypeSize(tensor_arg_abstract->getDataType()); + int64_t dtype_size = dataTypeSize(input_tv->dtype()); input_discontig_strides_[fusion_inp] = {}; - auto dims = tensor_arg_abstract->getAllocRank(); + int64_t dims = (int64_t)alloc_strides.size(); int64_t expected_stride = 1; - for (auto dim = dims - 1; dim >= 0; dim--) { - auto size = tensor_arg_abstract->getAllocSize((int)dim); + for (int64_t dim = dims - 1; dim >= 0; dim--) { + auto size = alloc_sizes.at(dim); if (size <= 1) { continue; } - auto stride = tensor_arg_abstract->getAllocStride((int)dim); + auto stride = alloc_strides.at(dim); if (stride != expected_stride) { input_discontig_strides_[fusion_inp].push_back(stride * dtype_size); expected_stride = stride; @@ -1030,11 +1028,10 @@ std::unique_ptr SchedulerRuntimeInfo:: const KernelArgumentHolder& args, PrecomputedValues* precomputed_values) { std::unique_ptr ee = - std::make_unique(); + std::make_unique( + executor_utils::bindInputs(args, complete_fusion_)); if (precomputed_values) { ee->bindPrecomputedValues(precomputed_values); - } else { - *ee = executor_utils::bindInputs(args, 
complete_fusion_); } return ee; } diff --git a/csrc/tensor_metadata.cpp b/csrc/tensor_metadata.cpp new file mode 100644 index 00000000000..e0ff07dede4 --- /dev/null +++ b/csrc/tensor_metadata.cpp @@ -0,0 +1,348 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include +#include +#include +#include +#include +#include +#include + +namespace nvfuser { + +namespace { + +// Forward traverse from rFactor domain to allocation domain, compute frontier +// sizes and strides, validate that splits are divisible and merges are +// contiguous, and update active_ids_ correspondingly. +class ForwardTraverseFromRFactorToAlloc { + ExpressionEvaluator& ee_; + std::unordered_map>& active_ids_; + + void handle(Split* split) { + auto in = split->in(); + auto inner = split->inner(); + auto outer = split->outer(); + auto in_it = active_ids_.find(in); + // TORCH_INTERNAL_ASSERT(in_it != active_ids_.end()) + if (in_it == active_ids_.end()) { + // TODO: see [Allocation domain on both side of rFactor] + return; + } + auto [in_size, in_stride] = in_it->second; + auto factor = ee_.evaluate(split->factor()).as(); + TORCH_INTERNAL_ASSERT( + in_size % factor == 0, + "The rFactor domain and allocation domain of fusion input/output ", + "tensors must be a one-to-one map, therefore, ", + "non-divisible split is not allowed in allocation domain"); + TORCH_INTERNAL_ASSERT(active_ids_.erase(in) == 1); + TORCH_INTERNAL_ASSERT( + active_ids_ + .emplace(inner, std::pair{factor, in_stride}) + .second); + TORCH_INTERNAL_ASSERT(active_ids_ + .emplace( + outer, + std::pair{ + in_size / factor, in_stride * factor}) + .second); + } + + void handle(Merge* merge) { + auto inner = merge->inner(); + auto outer = merge->outer(); + auto out = merge->out(); + auto inner_it = active_ids_.find(inner); + auto outer_it = active_ids_.find(outer); + // 
TORCH_INTERNAL_ASSERT(inner_it != active_ids_.end()) + // TORCH_INTERNAL_ASSERT(outer_it != active_ids_.end()) + if (inner_it == active_ids_.end() || outer_it == active_ids_.end()) { + // TODO: see [Allocation domain on both side of rFactor] + return; + } + auto [inner_size, inner_stride] = inner_it->second; + auto [outer_size, outer_stride] = outer_it->second; + TORCH_INTERNAL_ASSERT( + inner_stride * inner_size == outer_stride, + "The rFactor domain and allocation domain of fusion input/output ", + "tensors must be a one-to-one map, therefore, ", + "merging of discontiguous dimensions is not allowed in allocation domain"); + TORCH_INTERNAL_ASSERT(active_ids_.erase(inner) == 1); + TORCH_INTERNAL_ASSERT(active_ids_.erase(outer) == 1); + TORCH_INTERNAL_ASSERT(active_ids_ + .emplace( + out, + std::pair{ + inner_size * outer_size, inner_stride}) + .second); + } + + void handle(Expr* expr) { + if (auto split = dynamic_cast(expr)) { + handle(split); + } else if (auto merge = dynamic_cast(expr)) { + handle(merge); + } else { + TORCH_INTERNAL_ASSERT( + false, "Unsupported transformation in allocation domain"); + } + } + + public: + ForwardTraverseFromRFactorToAlloc( + ExpressionEvaluator& ee, + std::unordered_map>& active_ids) + : ee_(ee), active_ids_(active_ids) {} + + void run( + TensorView* tv, + const std::vector& rfactor, + const std::vector& alloc) { + auto forward_exprs = StmtSort::getExprsBetween( + tv->fusion(), + {rfactor.begin(), rfactor.end()}, + {alloc.begin(), alloc.end()}); + for (auto expr : forward_exprs) { + handle(expr); + } + } +}; + +// Similar to ForwardTraverseFromRFactorToAlloc, but in the opposite direction. 
+class BackwardTraverseFromRFactorToAlloc { + at::Tensor tensor_; + ExpressionEvaluator& ee_; + std::unordered_map>& active_ids_; + + void handle(Split* split) { + auto in = split->in(); + auto inner = split->inner(); + auto outer = split->outer(); + auto inner_it = active_ids_.find(inner); + auto outer_it = active_ids_.find(outer); + // TORCH_INTERNAL_ASSERT(inner_it != active_ids_.end()) + // TORCH_INTERNAL_ASSERT(outer_it != active_ids_.end()) + if (inner_it == active_ids_.end() || outer_it == active_ids_.end()) { + // TODO: see [Allocation domain on both side of rFactor] + return; + } + auto [inner_size, inner_stride] = inner_it->second; + auto [outer_size, outer_stride] = outer_it->second; + TORCH_INTERNAL_ASSERT( + inner_stride * inner_size == outer_stride, + "The rFactor domain and allocation domain of fusion input/output ", + "tensors must be a one-to-one map, therefore, ", + "splitting one dimension into discontiguous dimensions is not allowed in allocation domain"); + TORCH_INTERNAL_ASSERT(active_ids_.erase(inner) == 1); + TORCH_INTERNAL_ASSERT(active_ids_.erase(outer) == 1); + TORCH_INTERNAL_ASSERT(active_ids_ + .emplace( + in, + std::pair{ + inner_size * outer_size, inner_stride}) + .second); + } + + void handle(Merge* merge) { + auto inner = merge->inner(); + auto outer = merge->outer(); + auto out = merge->out(); + auto factor = ee_.evaluate(inner->extent()).as(); + auto out_it = active_ids_.find(out); + // TORCH_INTERNAL_ASSERT(out_it != active_ids_.end()) + if (out_it == active_ids_.end()) { + // TODO: see [Allocation domain on both side of rFactor] + return; + } + auto [out_size, out_stride] = out_it->second; + TORCH_INTERNAL_ASSERT( + out_size % factor == 0, + "The rFactor domain and allocation domain of fusion input/output ", + "tensors must be a one-to-one map, therefore, ", + "the size of the output must be divisible by the size of inner dimension"); + TORCH_INTERNAL_ASSERT(active_ids_.erase(out) == 1); + TORCH_INTERNAL_ASSERT( + active_ids_ + 
.emplace(inner, std::pair{factor, out_stride}) + .second); + TORCH_INTERNAL_ASSERT(active_ids_ + .emplace( + outer, + std::pair{ + out_size / factor, out_stride * factor}) + .second); + } + + void handle(Expr* expr) { + if (auto split = dynamic_cast(expr)) { + handle(split); + } else if (auto merge = dynamic_cast(expr)) { + handle(merge); + } else { + TORCH_INTERNAL_ASSERT( + false, "Unsupported transformation in allocation domain"); + } + } + + public: + BackwardTraverseFromRFactorToAlloc( + ExpressionEvaluator& ee, + std::unordered_map>& active_ids) + : ee_(ee), active_ids_(active_ids) {} + + void run( + TensorView* tv, + const std::vector& rfactor, + const std::vector& alloc) { + auto backward_exprs = StmtSort::getExprsBetween( + tv->fusion(), + {alloc.begin(), alloc.end()}, + {rfactor.begin(), rfactor.end()}); + std::reverse(backward_exprs.begin(), backward_exprs.end()); + for (auto expr : backward_exprs) { + handle(expr); + } + } +}; + +// Given an ATen tensor, whose sizes and strides are w.r.t. the rFactor domain +// of its corresponding TensorView, compute the sizes and strides of the tensor +// with respect to its allocation domain. +// For example, if the rFactor domain is [I1, I2], and the allocation domain is +// [I2*I1], and the tensor's size is [5, 3] and stride is [2, 10], then the +// resulting size will be [15] and stride will be [2] +// Another example, if the rFactor domain is [I1*I2] and the allocation domain +// is [I1, I2], and the tensor's size is [15] and stride is [7], and the extent +// of I2 is 5, then the resulting size will be [3, 5] and stride will be [35, 7] +std::pair, std::vector> +inferAndValidateAllocationSizesAndStrides( + const at::Tensor& tensor, + TensorView* tv, + ExpressionEvaluator ee) { + if (tv == nullptr || !tv->hasAllocation()) { + // When tv is nullptr, or tv does not have allocation, the given sizes and + // strides should already be in the target format. So nothing to do here. 
+ std::vector sizes; + std::vector strides; + for (auto i : c10::irange(tensor.dim())) { + sizes.emplace_back(tensor.size(i)); + strides.emplace_back(tensor.stride(i)); + } + return {sizes, strides}; + } + const auto& alloc = + TensorDomain::noReductions(tv->getMaybeAllocationDomain()); + const auto& rfactor = TensorDomain::noReductions(tv->getMaybeRFactorDomain()); + + // active IDs and their shape and stride + std::unordered_map> active_ids; + TORCH_INTERNAL_ASSERT((int64_t)rfactor.size() == tensor.dim()); + for (int64_t i : c10::irange((int64_t)rfactor.size())) { + auto rf_id = rfactor.at(i); + active_ids[rf_id] = {tensor.size(i), tensor.stride(i)}; + } + + ForwardTraverseFromRFactorToAlloc(ee, active_ids).run(tv, rfactor, alloc); + BackwardTraverseFromRFactorToAlloc(ee, active_ids).run(tv, rfactor, alloc); + + // Now active_ids should contain the final sizes and strides, unordered. We + // need to put them to the correct order. + std::vector sizes; + std::vector strides; + sizes.reserve(alloc.size()); + strides.reserve(alloc.size()); + for (auto i : c10::irange(alloc.size())) { + auto id = alloc.at(i); + sizes.emplace_back(active_ids.at(id).first); + strides.emplace_back(active_ids.at(id).second); + } + // Validate final sizes and strides with contiguity + int64_t contiguous_stride = 1; + std::vector> contiguity = tv->getContiguity(); + for (int64_t i = (int64_t)sizes.size() - 1; i >= 0; i--) { + if (alloc.at(i)->isBroadcast()) { + continue; + } + while (!contiguity.back().has_value()) { + contiguity.pop_back(); + } + auto size = sizes.at(i); + auto stride = strides.at(i); + TORCH_INTERNAL_ASSERT(!contiguity.empty()); + auto last_contiguity = contiguity.back(); + TORCH_INTERNAL_ASSERT( + last_contiguity.has_value(), + "I don't think this check makes sense, but unfortunately ", + "clang-tidy is not smart enough to infer from the context that this is always true."); + if (*last_contiguity) { + TORCH_CHECK( + stride == contiguous_stride, + "Stride mismatch with 
contiguity info. ", + "tv: ", + tv->toString(), + " allocation domain: ", + ir_utils::toString(tv->getMaybeAllocationDomain()), + " dim: ", + i, + " expected stride: ", + contiguous_stride, + " actual stride: ", + stride); + } + contiguous_stride = stride * size; + contiguity.pop_back(); + } + TORCH_INTERNAL_ASSERT( + contiguity.empty(), + "The size of contiguity mismatch with the dimensionality of allocation domain"); + // Validate that for expanded broadcast, the stride must be zero. + for (int64_t i : c10::irange((int64_t)strides.size())) { + if (auto alloc_id = alloc.at(i); alloc_id->hasExpandedExtent()) { + auto stride = strides.at(i); + TORCH_CHECK( + stride == 0, + "Expecting an expanded dimension on dimension ", + i, + " but found stride ", + stride); + } + } + return {sizes, strides}; +} + +} // namespace + +std::vector GetMetaData::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + TORCH_INTERNAL_ASSERT(inputs.size() == 1, "GetMetaData expects 1 input"); + TORCH_INTERNAL_ASSERT( + in()->isA(), + "Currently, GetMetaData only supports TensorView"); + TensorView* tv = in()->as(); + if (tv->getMemoryType() == MemoryType::Shared) { + // Smem tensor is defined locally as a pointer. It is impossible to know the + // actual address, but using nullptr is a good approximation. 
+ return {PolymorphicValue(Pointer(nullptr, tv->dtype()))}; + } + + at::Tensor input = inputs.at(0).as(); + + Struct concrete_value; + concrete_value["data"] = + PolymorphicValue(Pointer(input.data_ptr(), tv->dtype())); + concrete_value["logical_size"] = PolymorphicValue(input.sizes().vec()); + concrete_value["logical_stride"] = PolymorphicValue(input.strides().vec()); + std::tie(concrete_value["alloc_size"], concrete_value["alloc_stride"]) = + inferAndValidateAllocationSizesAndStrides(input, tv, ee); + return {PolymorphicValue(concrete_value)}; +} + +} // namespace nvfuser diff --git a/csrc/type.cpp b/csrc/type.cpp index 4aef763e1d1..abe87a643e7 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -35,11 +35,16 @@ DataType metaDataTypeOf(const Val* v) { StructOf tv_metadata; tv_metadata.name = ss.str(); + tv_metadata.field_names = {"data", "logical_size", "alloc_stride"}; tv_metadata.types["data"] = NVFUSER_MAYBE_MAKE_SHARED( PointerOf{std::make_shared(tv->dtype())}); - tv_metadata.types["size"] = NVFUSER_MAYBE_MAKE_SHARED2( + tv_metadata.types["logical_size"] = NVFUSER_MAYBE_MAKE_SHARED2( ArrayOf{std::make_shared(DataType::Index), dim}); - tv_metadata.types["stride"] = NVFUSER_MAYBE_MAKE_SHARED2( + tv_metadata.types["logical_stride"] = NVFUSER_MAYBE_MAKE_SHARED2( + ArrayOf{std::make_shared(DataType::Index), dim}); + tv_metadata.types["alloc_size"] = NVFUSER_MAYBE_MAKE_SHARED2( + ArrayOf{std::make_shared(DataType::Index), alloc_dim}); + tv_metadata.types["alloc_stride"] = NVFUSER_MAYBE_MAKE_SHARED2( ArrayOf{std::make_shared(DataType::Index), alloc_dim}); return tv_metadata; } @@ -211,9 +216,9 @@ static std::string data_type2string(DataType t) { } std::stringstream ss; ss << "struct { "; - for (auto& [name, type] : dtype.types) { - ss << data_type2string(NVFUSER_MAYBE_STAR type) << " " << name - << "; "; + for (auto& name : dtype.field_names) { + ss << data_type2string(NVFUSER_MAYBE_STAR dtype.types.at(name)) + << " " << name << "; "; } ss << "}"; return ss.str(); 
diff --git a/csrc/type.h b/csrc/type.h index 84bc55e9372..549ff864263 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -114,6 +114,11 @@ struct StructOf { // runtime/, and anonymous structs for others. std::string name; + // The ordered list of field names. This is used to generate the struct type + // on device. This list does not necessarily contain all the fields in the + // struct, but it should contain all the fields that are used on device. + std::vector field_names; + // Note [Incomplete type support in STL] // std::unordered_map is a STL container of incomplete // type. Not all C++ STL containers supports incomplete type due to historical diff --git a/runtime/tensor.cu b/runtime/tensor.cu index 8d5648bb28d..c0ee5436cf7 100644 --- a/runtime/tensor.cu +++ b/runtime/tensor.cu @@ -12,8 +12,8 @@ struct Tensor { }; T* data; - Array size; - Array stride; + Array logical_size; + Array alloc_stride; }; // Specialization for 0-dim case as it does not need size and stride arrays. diff --git a/test/test_evaluator.cpp b/test/test_evaluator.cpp index ebd89b3aeca..a725d25513e 100644 --- a/test/test_evaluator.cpp +++ b/test/test_evaluator.cpp @@ -318,8 +318,8 @@ TEST_F(ExprEvalTest, TensorMetaData) { TensorView* tv = makeSymbolicTensor(2); auto metadata = IrBuilder::metadataExpr(tv); auto data = IrBuilder::getAttrExpr(metadata, "data"); - auto sizes = IrBuilder::getAttrExpr(metadata, "size"); - auto strides = IrBuilder::getAttrExpr(metadata, "stride"); + auto sizes = IrBuilder::getAttrExpr(metadata, "logical_size"); + auto strides = IrBuilder::getAttrExpr(metadata, "alloc_stride"); auto size0 = IrBuilder::getItemExpr(sizes, fusion.zeroVal()); auto size1 = IrBuilder::getItemExpr(sizes, fusion.oneVal()); auto stride0 = IrBuilder::getItemExpr(strides, fusion.zeroVal()); diff --git a/test/test_expr_simplifier.cpp b/test/test_expr_simplifier.cpp index 7887f400ed9..d0a4260a277 100644 --- a/test/test_expr_simplifier.cpp +++ b/test/test_expr_simplifier.cpp @@ -699,8 +699,8 @@ 
TEST_F(ExprSimplifierTest, SignProve) { assertProvedNonZero("1"_); assertProvedNonZero("2"_); - assertProvedNonNegative("T123.size[3]"_); - assertProvedNonNegative("T123.stride[3]"_); + assertProvedNonNegative("T123.logical_size[3]"_); + assertProvedNonNegative("T123.alloc_stride[3]"_); std::vector assumptions{ "i1 < 2 && i1 >= 0"_, @@ -780,18 +780,18 @@ TEST_F(ExprSimplifierTest, DistributeGcdRemainderDivMod) { expectSimplifiedMod("i1 * 3 + 2"_, "6"_, "( i1 % 2 ) * 3 + 2"_, {"i1 >= 0"_}); expectSimplifiedDiv( "i1 * 4 + 3"_, - "32 * T0.size[0]"_, - "i1 / ( 8 * T0.size[0] )"_, + "32 * T0.logical_size[0]"_, + "i1 / ( 8 * T0.logical_size[0] )"_, {"i1 >= 0"_}); expectSimplifiedMod( "i1 * 4 + 3"_, - "32 * T0.size[0]"_, - "( i1 % ( 8 * T0.size[0] ) ) * 4 + 3"_, + "32 * T0.logical_size[0]"_, + "( i1 % ( 8 * T0.logical_size[0] ) ) * 4 + 3"_, {"i1 >= 0"_}); expectSimplifiedDiv( - "( ( ( blockIdx.x * 128 + threadIdx.x ) % ( T0.size[3] * 24 ) ) * 4 ) + 3"_, - "32 * T0.size[3]"_, - "( ( blockIdx.x * 128 + threadIdx.x ) % ( T0.size[3] * 24 ) ) / ( 8 * T0.size[3] )"_, + "( ( ( blockIdx.x * 128 + threadIdx.x ) % ( T0.logical_size[3] * 24 ) ) * 4 ) + 3"_, + "32 * T0.logical_size[3]"_, + "( ( blockIdx.x * 128 + threadIdx.x ) % ( T0.logical_size[3] * 24 ) ) / ( 8 * T0.logical_size[3] )"_, {}); } @@ -838,30 +838,32 @@ TEST_F(ExprSimplifierTest, Compare) { EXPECT_TRUE(*simplify("d1 >= d1 * d2"_, "d1 <= 0.0 && d2 >= 1.0"_)); EXPECT_TRUE( *simplifyExpr( - "ceilDiv( T0.size[0] , 128 ) * 4 >= ceilDiv( T0.size[0] , 128 )"_) + "ceilDiv( T0.logical_size[0] , 128 ) * 4 >= ceilDiv( T0.logical_size[0] , 128 )"_) ->getBool()); EXPECT_TRUE(*simplify("ceilDiv( i1 , i2 ) > 0"_, "i1 > 0 && i2 > 0"_)); EXPECT_TRUE(*simplify("ceilDiv( i1 , i2 ) >= 1"_, "i1 > 0 && i2 > 0"_)); EXPECT_TRUE(*simplify( - "blockIdx.x < ceilDiv( T0.size[0] , 128 ) * 4"_, - "blockIdx.x < ceilDiv( T0.size[0] , 128 ) * 4"_)); + "blockIdx.x < ceilDiv( T0.logical_size[0] , 128 ) * 4"_, + "blockIdx.x < ceilDiv( T0.logical_size[0] 
, 128 ) * 4"_)); EXPECT_TRUE(*simplify("i1 % i2 < i2"_, "i2 >= 0"_)); } TEST_F(ExprSimplifierTest, FundamentalDivisionWithRemainderProperty) { - EXPECT_TRUE( - isEquivalent("i1 / T1.size[0] * T1.size[0] + i1 % T1.size[0]"_, "i1"_)); EXPECT_TRUE(isEquivalent( - "( i2 + i1 / T1.size[0] * T1.size[0] ) + i1 % T1.size[0]"_, "i1 + i2"_)); + "i1 / T1.logical_size[0] * T1.logical_size[0] + i1 % T1.logical_size[0]"_, + "i1"_)); + EXPECT_TRUE(isEquivalent( + "( i2 + i1 / T1.logical_size[0] * T1.logical_size[0] ) + i1 % T1.logical_size[0]"_, + "i1 + i2"_)); EXPECT_TRUE(isEquivalent( - "( i1 / T1.size[0] ) * ( T1.size[0] * T1.size[1] ) + T1.size[1] * ( i1 % T1.size[0] )"_, - "i1 * T1.size[1]"_)); + "( i1 / T1.logical_size[0] ) * ( T1.logical_size[0] * T1.logical_size[1] ) + T1.logical_size[1] * ( i1 % T1.logical_size[0] )"_, + "i1 * T1.logical_size[1]"_)); EXPECT_TRUE(isEquivalent( - "i2 + ( i1 / T1.size[0] ) * ( T1.size[0] * T1.size[1] ) + T1.size[1] * ( i1 % T1.size[0] )"_, - "i1 * T1.size[1] + i2"_)); + "i2 + ( i1 / T1.logical_size[0] ) * ( T1.logical_size[0] * T1.logical_size[1] ) + T1.logical_size[1] * ( i1 % T1.logical_size[0] )"_, + "i1 * T1.logical_size[1] + i2"_)); } TEST_F(ExprSimplifierTest, ReducePredicateRegisterUsage) { @@ -1030,14 +1032,15 @@ TEST_F(ExprSimplifierTest, MinMax) { }; auto expr = - "max( max( ceilDiv( T0.size[0] , 128 ) * 4 , ceilDiv( T0.size[0] , 128 ) ) , 4 )"_; - EXPECT_TRUE(simplify(expr, "T0.size[0] > 0"_) - ->sameAs("ceilDiv( T0.size[0] , 128 ) * 4"_)); + "max( max( ceilDiv( T0.logical_size[0] , 128 ) * 4 , ceilDiv( T0.logical_size[0] , 128 ) ) , 4 )"_; + EXPECT_TRUE(simplify(expr, "T0.logical_size[0] > 0"_) + ->sameAs("ceilDiv( T0.logical_size[0] , 128 ) * 4"_)); } TEST_F(ExprSimplifierTest, PredicateDivToMul) { - auto simplified = simplifyExpr("i1 / T0.size[0] < i2"_, {}, {"i1 >= 0"_}); - auto expect = "i1 < ( i2 * T0.size[0] )"_; + auto simplified = + simplifyExpr("i1 / T0.logical_size[0] < i2"_, {}, {"i1 >= 0"_}); + auto expect = "i1 < ( 
i2 * T0.logical_size[0] )"_; EXPECT_TRUE(simplified->sameAs(expect)); } diff --git a/test/test_gpu1.cpp b/test/test_gpu1.cpp index 188bd66e524..b78a60c4a00 100644 --- a/test/test_gpu1.cpp +++ b/test/test_gpu1.cpp @@ -892,7 +892,7 @@ TEST_F(NVFuserTest, FusionParser_CUDA) { __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3) { int64_t i0; i0 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); - if ((i0 < T0.size[0])) { + if ((i0 < T0.logical_size[0])) { float T5[1]; T5[0] = 0; T5[0] diff --git a/test/test_gpu2.cpp b/test/test_gpu2.cpp index c5f84e7b3fd..73985284b13 100644 --- a/test/test_gpu2.cpp +++ b/test/test_gpu2.cpp @@ -2255,8 +2255,8 @@ TEST_F(NVFuserTest, FusionSimpleCompileRtc_CUDA) { std::string kernel = R"( __global__ void kernel1(Tensor T0, Tensor T1) { if(threadIdx.x==0){ - for(size_t ki28 = 0; ki28 < T0.size[0]; ++ki28) { - T1[ki28*T1.stride[0]] = T0[ki28*T0.stride[0]]*2; + for(size_t ki28 = 0; ki28 < T0.logical_size[0]; ++ki28) { + T1[ki28*T1.alloc_stride[0]] = T0[ki28*T0.alloc_stride[0]]*2; } } } @@ -2292,27 +2292,27 @@ __global__ void kernel1( Tensor out_var, Tensor out_avg ){ - for(int i0=0;i0 T0, Tensor<__half, 4, 4> T2, Tensor<__half, 4, 4> T7) { int64_t i0; - i0 = T0.size[2] * T0.size[1]; + i0 = T0.logical_size[2] * T0.logical_size[1]; int64_t i1; i1 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); int64_t i2; - i2 = (T0.size[1] * T0.size[2]) * T0.size[3]; + i2 = (T0.logical_size[1] * T0.logical_size[2]) * T0.logical_size[3]; int64_t i3; i3 = i1 % i2; int64_t i4; - i4 = T0.size[2] * T0.size[3]; + i4 = T0.logical_size[2] * T0.logical_size[3]; int64_t i5; i5 = i3 % i4; - if ((i1 < (((T0.size[0] * T0.size[1]) * T0.size[2]) * T0.size[3]))) { + if ((i1 < (((T0.logical_size[0] * T0.logical_size[1]) * T0.logical_size[2]) * T0.logical_size[3]))) { __half T9[1]; T9[0] = 0; T9[0] - = T2[(((((i0 * T0.size[3]) * (i1 / i2)) + (i0 * (i5 % T0.size[3]))) + (T0.size[2] * (i3 / i4))) + (i5 / 
T0.size[3]))]; + = T2[(((((i0 * T0.logical_size[3]) * (i1 / i2)) + (i0 * (i5 % T0.logical_size[3]))) + (T0.logical_size[2] * (i3 / i4))) + (i5 / T0.logical_size[3]))]; __half T8[1]; T8[0] = 0; T8[0] diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 668cac4e426..3a5fa452ad5 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -1731,7 +1731,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor int64_t i0; i0 = ((nvfuser_index_t)threadIdx.x) + (256 * ((nvfuser_index_t)blockIdx.x)); int64_t i1; - i1 = T0.size[0] * T0.size[1]; + i1 = T0.logical_size[0] * T0.logical_size[1]; bool b2; b2 = i0 < i1; float f3; diff --git a/test/test_loop_rotation.cpp b/test/test_loop_rotation.cpp index 5c71941fc91..b5b0d635b03 100644 --- a/test/test_loop_rotation.cpp +++ b/test/test_loop_rotation.cpp @@ -36,13 +36,13 @@ TEST_F(LoopRotationTest, RotateInner) { __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) { NVFUSER_DEFINE_MAGIC_ZERO; Array a0; - a0 = (T0).stride; + a0 = (T0).alloc_stride; int64_t i1; i1 = a0[0]; int64_t i2; i2 = a0[1]; #pragma unroll 1 - for(nvfuser_index_t i3 = 0; i3 < T0.size[0]; ++i3) { + for(nvfuser_index_t i3 = 0; i3 < T0.logical_size[0]; ++i3) { int64_t i4; i4 = i1 * i3; int64_t i5; @@ -108,7 +108,7 @@ TEST_F(LoopRotationTest, RotateOuter) { __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) { NVFUSER_DEFINE_MAGIC_ZERO; Array a0; - a0 = (T0).stride; + a0 = (T0).alloc_stride; int64_t i1; i1 = a0[1]; int64_t i2; @@ -133,13 +133,13 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor } NVFUSER_UPDATE_MAGIC_ZERO; #pragma unroll 1 - for(nvfuser_index_t i5 = 0; i5 < T0.size[0]; ++i5) { + for(nvfuser_index_t i5 = 0; i5 < T0.logical_size[0]; ++i5) { int64_t i6; i6 = 3 * i5; int64_t i7; i7 = i2 + (i2 * i5); bool b8; - b8 = (1 + i5) < T0.size[0]; + b8 = (1 + i5) < T0.logical_size[0]; // Alias Allocation - register auto& T3 = T1; #pragma unroll @@ -211,13 +211,13 @@ TEST_F(LoopRotationTest, NonDivisibleSplit) { __global__ void 
CUDAGeneratedKernel(Tensor T0, Tensor T4) { NVFUSER_DEFINE_MAGIC_ZERO; Array a0; - a0 = (T0).stride; + a0 = (T0).alloc_stride; int64_t i1; i1 = a0[0]; int64_t i2; i2 = a0[1]; int64_t i3; - i3 = T0.size[0] * T0.size[1]; + i3 = T0.logical_size[0] * T0.logical_size[1]; float T1[5]; float T2[5]; #pragma unroll @@ -231,7 +231,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor i5 = i4 + nvfuser_zero; if ((i5 < i3)) { T1[i4] - = T0[((i1 * (i5 / T0.size[1])) + (i2 * (i5 % T0.size[1])))]; + = T0[((i1 * (i5 / T0.logical_size[1])) + (i2 * (i5 % T0.logical_size[1])))]; } } NVFUSER_UPDATE_MAGIC_ZERO; @@ -242,7 +242,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor } NVFUSER_UPDATE_MAGIC_ZERO; #pragma unroll 1 - for(nvfuser_index_t i7 = 0; i7 < (ceilDiv((T0.size[0] * T0.size[1]), 5)); ++i7) { + for(nvfuser_index_t i7 = 0; i7 < (ceilDiv((T0.logical_size[0] * T0.logical_size[1]), 5)); ++i7) { int64_t i8; i8 = 5 * i7; int64_t i9; @@ -276,7 +276,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor i13 = i9 + (i4 + nvfuser_zero); if ((i13 < i3)) { T1[i4] - = T0[((i1 * (i13 / T0.size[1])) + (i2 * (i13 % T0.size[1])))]; + = T0[((i1 * (i13 / T0.logical_size[1])) + (i2 * (i13 % T0.logical_size[1])))]; } } NVFUSER_UPDATE_MAGIC_ZERO; @@ -321,7 +321,7 @@ TEST_F(LoopRotationTest, DoubleBuffered) { __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) { NVFUSER_DEFINE_MAGIC_ZERO; Array a0; - a0 = (T0).stride; + a0 = (T0).alloc_stride; int64_t i1; i1 = a0[0]; int64_t i2; @@ -336,7 +336,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor int64_t i6; i6 = i1 * i4; bool b7; - b7 = (i4 + nvfuser_zero) < T0.size[0]; + b7 = (i4 + nvfuser_zero) < T0.logical_size[0]; #pragma unroll for(nvfuser_index_t i8 = 0; i8 < 3; ++i8) { T1[(i5 + i8)] = 0; @@ -358,7 +358,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor } NVFUSER_UPDATE_MAGIC_ZERO; #pragma unroll 1 - for(nvfuser_index_t i10 = 0; i10 < T0.size[0]; ++i10) { + for(nvfuser_index_t i10 = 0; i10 < 
T0.logical_size[0]; ++i10) { int64_t i11; i11 = 4 + i10; int64_t i12; @@ -370,7 +370,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor int64_t i15; i15 = 3 * ((1 + i10) % 5); bool b16; - b16 = i11 < T0.size[0]; + b16 = i11 < T0.logical_size[0]; #pragma unroll for(nvfuser_index_t i8 = 0; i8 < 3; ++i8) { T1[(i12 + i8)] = 0; @@ -438,7 +438,7 @@ TEST_F(LoopRotationTest, SelectDoubleBufferLoad) { __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) { NVFUSER_DEFINE_MAGIC_ZERO; Array a0; - a0 = (T0).stride; + a0 = (T0).alloc_stride; int64_t i1; i1 = a0[1]; int64_t i2; @@ -448,7 +448,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor int64_t i4; i4 = 5 * i2; bool b5; - b5 = 4 < T0.size[0]; + b5 = 4 < T0.logical_size[0]; float T1[15]; #pragma unroll for(nvfuser_index_t i6 = 0; i6 < 3; ++i6) { @@ -468,7 +468,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor int64_t i9; i9 = i2 + (i2 * i7); bool b10; - b10 = ((1 + i7) + nvfuser_zero) < T0.size[0]; + b10 = ((1 + i7) + nvfuser_zero) < T0.logical_size[0]; #pragma unroll for(nvfuser_index_t i6 = 0; i6 < 3; ++i6) { T1[(i8 + i6)] = 0; @@ -503,7 +503,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor } NVFUSER_UPDATE_MAGIC_ZERO; #pragma unroll 1 - for(nvfuser_index_t i12 = 0; i12 < T0.size[0]; ++i12) { + for(nvfuser_index_t i12 = 0; i12 < T0.logical_size[0]; ++i12) { int64_t i13; i13 = 3 * i12; int64_t i14; @@ -513,7 +513,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor int64_t i16; i16 = 3 * ((1 + i12) % 5); bool b17; - b17 = (5 + i12) < T0.size[0]; + b17 = (5 + i12) < T0.logical_size[0]; float T3[3]; #pragma unroll for(nvfuser_index_t i18 = 0; i18 < 3; ++i18) { @@ -596,13 +596,13 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor unsigned smem_offset = 0; NVFUSER_DEFINE_MAGIC_ZERO; Tensor s0; - s0.stride = T0.stride; - s0.size = T0.size; s0.data = T0.data; + s0.logical_size = T0.logical_size; + s0.alloc_stride = T0.alloc_stride; float* ptr1; ptr1 = s0.data; Array a2; - a2 = 
s0.stride; + a2 = s0.alloc_stride; int64_t i3; i3 = a2[0]; int64_t i4; @@ -619,7 +619,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor unsigned i8; i8 = (toSmem((T4))) + (12 * i6); bool b9; - b9 = (i6 + nvfuser_zero) < T0.size[0]; + b9 = (i6 + nvfuser_zero) < T0.logical_size[0]; #pragma unroll for(nvfuser_index_t i10 = 0; i10 < 3; ++i10) { Ampere::cpAsyncCa((i8 + (4 * i10)), (ptr7 + (i4 * (i10 + nvfuser_zero))), b9); @@ -632,7 +632,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1[0] = T4[0]; #pragma unroll 1 - for(nvfuser_index_t i11 = 0; i11 < T0.size[0]; ++i11) { + for(nvfuser_index_t i11 = 0; i11 < T0.logical_size[0]; ++i11) { float* ptr12; ptr12 = ptr5 + (i3 * i11); int64_t i13; @@ -644,7 +644,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor int64_t i16; i16 = 3 * i11; bool b17; - b17 = i13 < T0.size[0]; + b17 = i13 < T0.logical_size[0]; #pragma unroll for(nvfuser_index_t i10 = 0; i10 < 3; ++i10) { Ampere::cpAsyncCa((i14 + (4 * i10)), (ptr12 + (i4 * (i10 + nvfuser_zero))), b17); diff --git a/test/test_tensor_factories.cpp b/test/test_tensor_factories.cpp index 9c38262c444..9e05fe20bf8 100644 --- a/test/test_tensor_factories.cpp +++ b/test/test_tensor_factories.cpp @@ -539,10 +539,10 @@ TEST_F(TensorFactoryTest, MetadataAsTensor) { auto meta0_copy2 = set(meta0_copy1); auto meta1_copy2 = set(meta1_copy1); - auto size0 = IrBuilder::getAttrExpr(meta0_copy2, "size"); - auto stride0 = IrBuilder::getAttrExpr(meta0_copy2, "stride"); - auto size1 = IrBuilder::getAttrExpr(meta1_copy2, "size"); - auto stride1 = IrBuilder::getAttrExpr(meta1_copy2, "stride"); + auto size0 = IrBuilder::getAttrExpr(meta0_copy2, "logical_size"); + auto stride0 = IrBuilder::getAttrExpr(meta0_copy2, "alloc_stride"); + auto size1 = IrBuilder::getAttrExpr(meta1_copy2, "logical_size"); + auto stride1 = IrBuilder::getAttrExpr(meta1_copy2, "alloc_stride"); auto output = tensor(std::vector{size0, stride0, size1, stride1}); fusion->addOutput(output);