From 7dcd53e59004cf08269fe1a5f67fabbe31aa6b05 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Wed, 29 Jan 2025 13:35:57 -0800 Subject: [PATCH 01/16] Move all ir evaluation definitions to one file. --- CMakeLists.txt | 1 + csrc/ir/base_nodes.cpp | 44 -- csrc/ir/evaluate.cpp | 1073 ++++++++++++++++++++++++++++++++++++++++ csrc/ir/nodes.cpp | 1018 -------------------------------------- 4 files changed, 1074 insertions(+), 1062 deletions(-) create mode 100644 csrc/ir/evaluate.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e4127e72c4..5e0a6300370 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -155,6 +155,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/ir/builder.cpp ${NVFUSER_SRCS_DIR}/ir/cloner.cpp ${NVFUSER_SRCS_DIR}/ir/container.cpp + ${NVFUSER_SRCS_DIR}/ir/evaluate.cpp ${NVFUSER_SRCS_DIR}/ir/graphviz.cpp ${NVFUSER_SRCS_DIR}/ir/iostream.cpp ${NVFUSER_SRCS_DIR}/ir/nodes.cpp diff --git a/csrc/ir/base_nodes.cpp b/csrc/ir/base_nodes.cpp index 6e7e53e0f4d..88b7e93a18c 100644 --- a/csrc/ir/base_nodes.cpp +++ b/csrc/ir/base_nodes.cpp @@ -6,7 +6,6 @@ */ // clang-format on #include -#include #include #include #include @@ -200,20 +199,6 @@ bool Val::isConstInt() const { return ir_utils::dependenciesSatisfied(this) && isIntegralScalar(); } -PolymorphicValue Val::evaluate() { - if (this->value().hasValue()) { - return this->value(); - } - - ExpressionEvaluator ee; - auto evaluated_val = ee.evaluate(this); - NVF_ERROR( - evaluated_val.hasValue(), - "Detected a const value but failed to infer its value: ", - toInlineString()); - return evaluated_val; -} - bool Val::isZero() const { return value().hasValue() && (bool)(value() == 0.0); } @@ -376,33 +361,4 @@ Expr* Expr::withWritePredicate(kir::Predicate* predicate) { return result; } -std::vector Expr::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - NVF_THROW( - "`evaluate` method for expression ", - getOpString(), - " is not defined. 
", - "Please override the evaluate method"); -} - -std::vector Expr::evaluate( - const ExpressionEvaluator& ee, - std::unordered_map& known_values) const { - std::vector expr_inputs; - expr_inputs.reserve(inputs().size()); - for (auto inp : inputs()) { - const auto& eval_i = ee.evaluate(inp, known_values); - if (!eval_i.hasValue()) { - return {std::monostate{}}; - } - expr_inputs.emplace_back(eval_i); - } - return this->evaluate(ee, expr_inputs); -} - -void Expr::addDataAttribute(PolymorphicValue attr) { - addAttribute(IrBuilder::createInContainer(container(), std::move(attr))); -} - } // namespace nvfuser diff --git a/csrc/ir/evaluate.cpp b/csrc/ir/evaluate.cpp new file mode 100644 index 00000000000..3d6f7ca63a6 --- /dev/null +++ b/csrc/ir/evaluate.cpp @@ -0,0 +1,1073 @@ +#include +#include +#include +#include +#include + +namespace nvfuser { + +PolymorphicValue Val::evaluate() { + if (this->value().hasValue()) { + return this->value(); + } + + ExpressionEvaluator ee; + auto evaluated_val = ee.evaluate(this); + NVF_ERROR( + evaluated_val.hasValue(), + "Detected a const value but failed to infer its value: ", + toInlineString()); + return evaluated_val; +} + +std::vector Expr::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + NVF_THROW( + "`evaluate` method for expression ", + getOpString(), + " is not defined. 
", + "Please override the evaluate method"); +} + +std::vector Expr::evaluate( + const ExpressionEvaluator& ee, + std::unordered_map& known_values) const { + std::vector expr_inputs; + expr_inputs.reserve(inputs().size()); + for (auto inp : inputs()) { + const auto& eval_i = ee.evaluate(inp, known_values); + if (!eval_i.hasValue()) { + return {std::monostate{}}; + } + expr_inputs.emplace_back(eval_i); + } + return this->evaluate(ee, expr_inputs); +} + +void Expr::addDataAttribute(PolymorphicValue attr) { + addAttribute(IrBuilder::createInContainer(container(), std::move(attr))); +} +std::vector FullOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + std::vector shape; + for (auto i : c10::irange(inputs.size() - 1)) { + shape.push_back(inputs.at(i).as()); + } + DataType dtype = getFillValue()->getDataType().value(); + const auto options = + at::TensorOptions().device(at::kCUDA).dtype(data_type_to_aten(dtype)); + using namespace PolymorphicValue_functions; + return {at::full(shape, toScalar(inputs.back()), options)}; +} + +std::vector SelectOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto& in = inputs.at(0).as(); + int64_t dimension = dim(); + int64_t index = (int64_t)inputs.at(1); + return {in.select(dimension, index)}; +} + +std::vector IndexSelectOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto& in = inputs.at(0).as(); + int64_t dimension = dim(); + const auto& indices = inputs.at(1).as().squeeze(); + return {at::index_select(in, dimension, indices)}; +} + +std::vector TorchGatherOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto& input = inputs.at(0).as(); + const auto& index = inputs.at(1).as(); + auto dimension = dim(); + if (exactSizes()) { + return {at::take_along_dim(input, index, dimension)}; + } else { + return {at::gather(input, dimension, index)}; + } +} + +std::vector 
ScatterOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto& input = inputs.at(0).as(); + const auto& index = inputs.at(1).as(); + const auto& src = inputs.at(2).as(); + auto dimension = dim(); + return {at::scatter(input, dimension, index, src)}; +} + +std::vector IotaOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto options = + at::TensorOptions().device(at::kCUDA).dtype(data_type_to_aten(dtype())); + int64_t length = (int64_t)inputs.at(0); + + if (isIntegralType(dtype())) { + int64_t start = (int64_t)inputs.at(1); + int64_t step = (int64_t)inputs.at(2); + int64_t end = start + step * length; + return {at::arange(start, end, step, options)}; + } else if (isFloatingPointType(dtype())) { + double start = (double)inputs.at(1); + double step = (double)inputs.at(2); + // Due to rounding error, it can be hard to guarantee the size of + // the output of arange to be exactly length, so we generate a + // larger tensor and truncate it to length. 
+ double end = start + step * ((double)length + 1); + return {at::arange(start, end, step, options).narrow(0, 0, length)}; + } else { + NVF_THROW("Unsupported dtype in IotaOp evaluator: ", dtype()); + } +} + +std::vector EyeOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto options = + at::TensorOptions().device(at::kCUDA).dtype(data_type_to_aten(dtype())); + int64_t nrows = (int64_t)inputs.at(0); + if (inputs.size() > 1) { + int64_t ncols = (int64_t)inputs.at(1); + return {at::eye(nrows, ncols, options)}; + } else { + return {at::eye(nrows, options)}; + } +} + +std::vector UnaryOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + using namespace PolymorphicValue_functions; + + const auto& in = inputs.at(0); + if (!in.hasValue()) { + return {std::monostate{}}; + } + + switch (getUnaryOpType()) { + case UnaryOpType::Neg: + return {-in}; + case UnaryOpType::Cast: + if (in.is()) { + return {PolymorphicValue( + in.as().to(data_type_to_aten(out()->dtype())))}; + } else if (isIntegralType(*out()->getDataType())) { + return {PolymorphicValue((int64_t)in)}; + } else if (isFloatingPointType(*out()->getDataType())) { + return {PolymorphicValue((double)in)}; + } else if (out()->getDataType() == DataType::Bool) { + return {PolymorphicValue((bool)in)}; + } else if (isComplexType(*out()->getDataType())) { + return {PolymorphicValue((std::complex)in)}; + } else { + NVF_THROW("dtype not supported in evaluator: ", *out()->getDataType()); + } + case UnaryOpType::Reciprocal: + return {1.0 / in}; + break; + case UnaryOpType::Abs: + return {abs(in)}; + break; + case UnaryOpType::LogicalNot: + return {!in}; + break; + case UnaryOpType::BitwiseNot: + return {~in}; + break; + case UnaryOpType::Erf: + return {erf(in)}; + break; + case UnaryOpType::ToUnsignedSmemAddr: + return {(int64_t)(unsigned)in}; + break; + case UnaryOpType::AdjustPartialLdMatrixAddrInTuring8: + case 
UnaryOpType::AdjustPartialLdMatrixAddrInTuring16: + return {in}; + break; + case UnaryOpType::Dereference: + if (*out()->getDataType() == DataType::Float) { + return {PolymorphicValue((double)*(float*)in)}; + } else { + NVF_THROW("dtype not supported in evaluator: ", *out()->getDataType()); + } + break; + case UnaryOpType::Sigmoid: + return {in.as().sigmoid()}; + break; + case UnaryOpType::Tanh: + return {in.as().tanh()}; + break; + case UnaryOpType::Relu: + return {at::relu(in.as())}; + break; + case UnaryOpType::Gelu: + return {at::gelu(in.as())}; + break; + case UnaryOpType::Exp: + return {at::exp(in.as())}; + break; + case UnaryOpType::Sin: + return {in.as().sin()}; + break; + case UnaryOpType::Signbit: + return {signbit(in)}; + break; + case UnaryOpType::Cos: + return {in.as().cos()}; + break; + case UnaryOpType::BitCast: + NVF_CHECK( + dataTypeSize(input(0)->dtype()) == dataTypeSize(out()->dtype()), + "BitCast only works for types of the same size"); + if (isComplexType(input(0)->dtype()) && + std::holds_alternative(out()->dtype().type)) { + // view_as_real case. 
+ auto vec_type = std::get(out()->dtype().type); + auto inp_scalar_type = getTypeFromComplexType(input(0)->dtype()); + NVF_CHECK( + *vec_type.type == inp_scalar_type, + "Output type must be the same as the scalar type of the complex input."); + NVF_CHECK( + vec_type.size == 2, + "Expected output to be array of size 2, found array of size ", + vec_type.size); + return {in.as()}; + } else { + return {in.as().view(data_type_to_aten(out()->dtype()))}; + } + break; + case UnaryOpType::Rsqrt: + return {in.as().rsqrt()}; + break; + case UnaryOpType::Real: + return {at::real(in.as())}; + break; + case UnaryOpType::Imag: + return {at::imag(in.as())}; + break; + case UnaryOpType::Tan: + return {in.as().tan()}; + break; + case UnaryOpType::IsFinite: + return {at::isfinite(in.as())}; + break; + default: + NVF_CHECK( + false, + "Unexpected operator type ", + getUnaryOpType(), + " in ", + toString()); + } +} + +std::vector BinaryOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + using namespace PolymorphicValue_functions; + const auto& lhs = inputs.at(0); + const auto& rhs = inputs.at(1); + + switch (getBinaryOpType()) { + case BinaryOpType::Add: + return {lhs + rhs}; + break; + case BinaryOpType::Sub: + return {lhs - rhs}; + break; + case BinaryOpType::Mul: + return {lhs * rhs}; + break; + case BinaryOpType::Div: + return {lhs / rhs}; + break; + case BinaryOpType::Mod: + NVF_CHECK(rhs != 0); + return {lhs % rhs}; + break; + case BinaryOpType::Fmod: + NVF_CHECK(rhs != 0); + return {fmod(lhs, rhs)}; + break; + case BinaryOpType::CeilDiv: + NVF_CHECK(rhs != 0); + return {ceildiv(lhs, rhs)}; + break; + case BinaryOpType::LogicalAnd: + return {lhs && rhs}; + break; + case BinaryOpType::LogicalOr: + return {lhs || rhs}; + break; + case BinaryOpType::BitwiseAnd: + return {lhs & rhs}; + break; + case BinaryOpType::BitwiseOr: + return {lhs | rhs}; + break; + case BinaryOpType::BitwiseXor: + return {lhs ^ rhs}; + break; + case BinaryOpType::Eq: + return 
{eq(lhs, rhs)}; + break; + case BinaryOpType::NE: + return {ne(lhs, rhs)}; + break; + case BinaryOpType::GT: + return {gt(lhs, rhs)}; + break; + case BinaryOpType::GE: + return {ge(lhs, rhs)}; + break; + case BinaryOpType::LT: + return {lt(lhs, rhs)}; + break; + case BinaryOpType::LE: + return {le(lhs, rhs)}; + break; + case BinaryOpType::Max: + return {max(lhs, rhs)}; + break; + case BinaryOpType::Min: + return {min(lhs, rhs)}; + break; + case BinaryOpType::Gcd: + return {gcd(lhs, rhs)}; + break; + case BinaryOpType::Lshift: + return {lhs << rhs}; + break; + case BinaryOpType::Rshift: + return {lhs >> rhs}; + break; + case BinaryOpType::Complex: + return {at::complex(lhs.as(), rhs.as())}; + break; + case BinaryOpType::Pow: + return {pow(lhs, rhs)}; + break; + default: + NVF_CHECK( + false, + "Unexpected operator type: ", + getBinaryOpType(), + " in ", + toString()); + } +} + +std::vector TernaryOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + using namespace PolymorphicValue_functions; + const auto& a = inputs.at(0); + const auto& b = inputs.at(1); + const auto& c = inputs.at(2); + switch (getTernaryOpType()) { + case TernaryOpType::Clamp: + return {std::min(std::max(a, b), c)}; + break; + case TernaryOpType::Lerp: + // This is the same lerp computed in helpers.cu + // https://math.stackexchange.com/a/1798323 + return {(c < 0.5) ? a + c * (b - a) : b - (b - a) * (1.0 - c)}; + break; + case TernaryOpType::Threshold: + return {(a <= b) ? c : a}; + break; + case TernaryOpType::Where: + return {a.as() ? 
b : c}; + break; + default: + NVF_CHECK( + false, + "Unexpected operator type: ", + getTernaryOpType(), + " in ", + toString()); + } +} + +std::vector ArrayConstruct::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + return {PolymorphicValue(inputs)}; +} + +std::vector ReverseArray::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + NVF_ERROR(inputs.size() == 1, "ReverseArray expects 1 input"); + PolymorphicValue array = inputs.at(0); + auto& vec = array.as(); + std::reverse(vec.begin(), vec.end()); + return {std::move(array)}; +} + +std::vector GetItem::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + NVF_ERROR(inputs.size() == 2, "GetItem expects 2 inputs"); + return {PolymorphicValue(inputs.at(0)[inputs.at(1)])}; +} + +std::vector StructConstruct::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + NVF_ERROR( + this->inputs().size() == inputs.size(), + "StructConstruct expects ", + this->inputs().size(), + " inputs"); + PolymorphicValue struct_ = + std::get(output(0)->dtype().type).create(); + for (int64_t i : c10::irange((int64_t)inputs.size())) { + struct_->*attribute(i) = inputs.at(i); + } + return {std::move(struct_)}; +} + +std::vector GetAttr::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + NVF_ERROR(inputs.size() == 1, "GetAttr expects 1 input"); + return {inputs.at(0)->*attr()}; +} + +std::vector TensorConstruct::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + NVF_ERROR(inputs.size() == 1, "TensorConstruct expects 1 input"); + using namespace PolymorphicValue_functions; + return {toTensor(inputs.at(0))}; +} + +std::vector BroadcastOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + NVF_ERROR( + inputs.size() == 1, + "BroadcastOp expects exactly 1 input, but received ", + inputs.size()); + std::vector out_shape; + const auto& in = 
inputs.at(0).as(); + int64_t idx = 0; + for (bool b : getBroadcastDimFlags()) { + if (b) { + out_shape.push_back(1); + } else { + out_shape.push_back(in.sizes()[idx++]); + } + } + return {in.view(out_shape)}; +} + +std::vector SqueezeOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + NVF_ERROR( + inputs.size() == 1, + "SqueezeOp expects exactly 1 input, but received ", + inputs.size()); + std::vector out_shape; + const auto& in = inputs.at(0).as(); + const auto& is_squeeze_dims = getSqueezeDimFlags(); + NVF_ERROR( + (int64_t)is_squeeze_dims.size() == in.dim(), + "The dimensions of input tensor and does not match with is_squeeze_dims"); + at::Tensor out = in; + for (int64_t i : c10::irange((int64_t)is_squeeze_dims.size())) { + if (is_squeeze_dims[i]) { + if (in.stride(i) == 0) { + // If the input dimension is expanded in this dimension, undo the expand + // by slicing. This ensures that any broadcast dimensions will be + // unexpanded when we do the final call to view() + out = out.slice(i, 0, 1); + } + } else { + out_shape.push_back(in.sizes()[i]); + } + } + return {out.view(out_shape)}; +} + +std::vector ReductionOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto& input = inputs.at(0).as(); + const auto output = out()->as(); + + NVF_ERROR( + !output->hasRoot(), + "Evaluation for rFactored reductions is not supported."); + + std::vector reduction_axes; + for (const auto i : c10::irange(int64_t(output->getLogicalDomain().size()))) { + auto ax = output->getLogicalDomain().at(i); + if (ax->isReduction()) { + reduction_axes.push_back(i); + } + } + switch (getReductionOpType()) { + case BinaryOpType::Add: + return {at::sum(input, reduction_axes)}; + break; + case BinaryOpType::Max: + return {at::amax(input, reduction_axes)}; + break; + case BinaryOpType::Min: + return {at::amin(input, reduction_axes)}; + break; + default: + NVF_CHECK( + false, + "Unexpected operator type: ", + 
getReductionOpType(), + " in ", + toString()); + } +} + +std::vector GroupedReductionOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto num_reductions = numHorizontallyGroupedExprs(); + std::vector grouped_reduction_out; + grouped_reduction_out.reserve(num_reductions); + for (const auto i : c10::irange(num_reductions)) { + const auto& in_tensor = inputs.at(i).as(); + const auto out_tv = output(i)->as(); + NVF_ERROR( + !out_tv->hasRoot(), + "Evaluation for rFactored reductions is not supported."); + + std::vector reduction_axes; + for (const auto id : + c10::irange(int64_t(out_tv->getLogicalDomain().size()))) { + auto ax = out_tv->getLogicalDomain().at(id); + if (ax->isReduction()) { + reduction_axes.push_back(id); + } + } + switch (getReductionOpType(i)) { + case BinaryOpType::Add: + grouped_reduction_out.emplace_back(at::sum(in_tensor, reduction_axes)); + break; + case BinaryOpType::Max: + grouped_reduction_out.emplace_back(at::amax(in_tensor, reduction_axes)); + break; + default: + NVF_CHECK( + false, + "Unexpected operator type: ", + getReductionOpType(i), + " in ", + toString()); + } + } + return grouped_reduction_out; +} + +std::vector WelfordOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + NVF_ERROR( + !hasInit(), + "Evaluation for WelfordOp is not implemented for non-empty initial values."); + const auto& in_tensor = inputs.at(0).as(); + const auto out_tv = out()->as(); + NVF_ERROR( + !out_tv->hasRoot(), + "Evaluation for WelfordOp is not supported when output is rFactored."); + + int64_t N = 1; + std::vector reduction_axes; + for (const auto i : c10::irange(int64_t(out_tv->getLogicalDomain().size()))) { + auto ax = out_tv->getLogicalDomain().at(i); + if (ax->isReduction()) { + reduction_axes.push_back(i); + N *= in_tensor.size(i); + } + } + const auto [in_var, in_avg] = + at::var_mean(in_tensor, reduction_axes, false, false); + return {in_avg, in_var * N, N}; +} + 
+std::vector ExpandOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto& in = inputs.at(0).as(); + std::vector expanded_size; + for (auto i : c10::irange(1, inputs.size())) { + expanded_size.push_back((int64_t)inputs.at(i)); + } + return {in.expand(expanded_size)}; +} + +std::vector RepeatOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + NVF_ERROR( + inputs.size() == 1, + "RepeatOp expects exactly 1 input, but received ", + inputs.size()); + auto tensor = inputs.at(0).as(); + std::vector multipliers; + multipliers.reserve(out()->getLogicalDomain().size()); + const auto c2p = + PairwiseLogicalDomainMap(in(), out()).mapConsumerToProducer(); + for (const auto i : c10::irange(out()->getLogicalDomain().size())) { + auto out_id = out()->getLogicalDomain().at(i); + auto inp_id = c2p.at(out_id); + auto out_extent = ee.evaluate(out_id->extent()).as(); + auto inp_extent = ee.evaluate(inp_id->extent()).as(); + NVF_ERROR( + out_extent % inp_extent == 0, + "For dimension ", + i, + ", the output extent (", + out_extent, + " should be a multiple of the input extent (", + inp_extent, + ")."); + multipliers.push_back(out_extent / inp_extent); + } + return {tensor.repeat(multipliers)}; +} + +std::vector ViewAsScalar::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const at::Tensor& in = inputs.at(0).as(); + return {at::view_as_real(in)}; +} + +std::vector ViewOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + FUSER_PERF_SCOPE("ViewOp::evaluate"); + NVF_ERROR(inputs.size() == 1); + const at::Tensor& in_tensor = inputs[0].as(); + + const std::vector& out_logical = out()->getLogicalDomain(); + std::vector out_shape; + out_shape.reserve(out_logical.size()); + for (IterDomain* id : out_logical) { + if (id->isDeviceDim()) { + out_shape.push_back(1); + } else { + out_shape.push_back( + ee.evaluate(id->getMaybeExpandedExtent()).as()); + } + } + 
+ // TODO: check allocation domain and contiguity. + + // Use `at::Tensor::reshape` instead of `at::Tensor::view` because `ViewOp` + // doesn't always produce an alias. For example, when merging an expanded + // `IterType::Broadcast` and an `IterType::Iteration`, `ViewOp` has to realize + // the expand. + return {in_tensor.reshape(out_shape)}; +} + +std::vector LoadStoreOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + FUSER_PERF_SCOPE("LoadStoreOp::evaluate"); + if (TensorView* out_tv = dynamic_cast(out())) { + if (out_tv->hasRoot()) { + std::optional> permutation = + ir_utils::computePermutation( + out_tv->getRootDomain(), out_tv->getLogicalDomain()); + NVF_ERROR( + permutation.has_value(), + "The logical domain of a Set.Permute is supposed to be a permutation of the root domain: ", + out_tv->toString()); + NVF_ERROR(inputs.size() == 1); + at::Tensor in_tensor = inputs[0].as(); + at::Tensor out_tensor = in_tensor.permute(*permutation); + return {out_tensor}; + } + } + return inputs; +} + +std::vector PadOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto& in = inputs.at(0).as(); + + std::vector pad_widths; + auto pad_width_offset = getPadWidthInputOffset(); + auto num_dims = in.dim(); + + for (auto i = num_dims - 1; i > -1; i--) { + auto left_pad = (int64_t)inputs.at(pad_width_offset + 2 * i); + auto right_pad = (int64_t)inputs.at(pad_width_offset + 2 * i + 1); + pad_widths.push_back(left_pad); + pad_widths.push_back(right_pad); + } + + if (isComplexType(*out()->getDataType())) { + std::complex value = + static_cast>(inputs.at(1)); + auto real = at::real(in); + auto imag = at::imag(in); + auto padded_real = at::pad(real, pad_widths, "constant", value.real()); + auto padded_imag = at::pad(imag, pad_widths, "constant", value.imag()); + return {at::complex(padded_real, padded_imag)}; + } else { + double value = static_cast(inputs.at(1)); + return {at::pad(in, pad_widths, "constant", 
value)}; + } +} + +std::vector SliceOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto& in = inputs.at(0).as(); + std::vector ranges; + auto ranges_offset = getRangeInputOffset(); + auto num_dims = in.dim(); + for (const auto i : c10::irange(num_dims)) { + auto start = (int64_t)inputs.at(ranges_offset + 3 * i); + auto stop = (int64_t)inputs.at(ranges_offset + 3 * i + 1); + auto step = (int64_t)inputs.at(ranges_offset + 3 * i + 2); + ranges.emplace_back(at::indexing::Slice(start, stop, step)); + } + return {in.index(ranges)}; +} + +std::vector CatOp::evaluate( + const ExpressionEvaluator& ee, + std::unordered_map& known_values) const { + // CatOp is preceded by a PadOp internally. + // For ATen evaluation, directly compute the unpadded inputs. + std::vector unpadded_inputs; + unpadded_inputs.reserve(inputs().size()); + int64_t concat_dim = concatenatedDim(); + for (Val* inp : inputs()) { + NVF_CHECK( + inp->definition() != nullptr && inp->definition()->isA(), + "Expected CatOp to be preceded by a PadOp."); + auto eval_i = ee.evaluate(inp->definition()->input(0), known_values); + unpadded_inputs.push_back(eval_i.as()); + } + return {at::cat(unpadded_inputs, concat_dim)}; +} + +std::vector MatmulOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto a = inputs.at(0).as(); + const auto b = inputs.at(1).as(); + + auto matmul_out = at::matmul(a, b); + + // When the contracting dimension is sharded, each device has a partial + // matmul output and is followed by an allreduce. For loop split, this is + // represented as an rfactored reduction. The local matmul logical domain + // after the rfactor is: i{DIDx}, i{M}, i{N}, r{K//d}. Unsqueeze the + // rfactored DID axis to correctly bind with the logical domain. 
See + // tests/python/test_multidevice.py/test_matmul_allreduce_loop_split + auto out_logical = TensorDomain::noReductions(out()->getLogicalDomain()); + int64_t rfactor_did_idx = -1; + for (auto idx : c10::irange(static_cast(out_logical.size()))) { + if (!out_logical.at(idx)->isRFactorProduct() || + !out_logical.at(idx)->isDeviceDim()) { + continue; + } + if (rfactor_did_idx != -1) { + NVF_THROW( + "Expected only 1 rfactored DID iterdomain, found at least 2 in ", + out_logical); + } + rfactor_did_idx = idx; + } + + if (rfactor_did_idx != -1) { + matmul_out = matmul_out.unsqueeze(rfactor_did_idx); + } + + const auto& [sizes, strides] = inferShapeOfOutput(out(), ee); + auto meta_out = at::detail::empty_strided_meta(sizes, strides, a.dtype()); + + if (meta_out.is_contiguous()) { + return {matmul_out}; + } + + auto strided_matmul_out = at::empty_strided(sizes, strides, a.options()); + strided_matmul_out = strided_matmul_out.copy_(matmul_out); + return {strided_matmul_out}; +} + +std::vector LinearOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto in = inputs.at(0).as(); + auto weight = inputs.at(1).as(); + + auto squeeze_device_dims = [](at::Tensor& t, + int64_t num_device_dims) -> void { + // Record the initial shape for the error message. + std::vector shape = t.sizes().vec(); + for ([[maybe_unused]] auto _ : c10::irange(num_device_dims)) { + NVF_CHECK( + t.size(0) == 1, + "When the weight is >2D, expect its preceding dimensions and " + "the bias's preceding dimensions to " + "be DID-parallel and therefore size-1: ", + shape); + t = t.squeeze(0); + } + }; + + // The squeezes and unsqueezes are currently required to support a sharded + // linear layer. Remove them after #2563. 
+ auto num_device_dims = weight.dim() - 2; + squeeze_device_dims(weight, num_device_dims); + + at::Tensor out; + if (has_bias()) { + auto bias = inputs.at(2).as(); + squeeze_device_dims(bias, num_device_dims); + out = at::linear(in, weight, bias); + } else { + out = at::linear(in, weight); + } + + for ([[maybe_unused]] auto _ : c10::irange(num_device_dims)) { + out = out.unsqueeze(0); + } + return {out}; +} + +std::vector SdpaFwdOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + auto query = inputs.at(0).as(); + auto key = inputs.at(1).as(); + auto value = inputs.at(2).as(); + + const auto dropout_p = inputs.at(3).as(); + const auto is_causal = inputs.at(4).as(); + + // Temporary handling of DID parallelization see + // https://github.com/NVIDIA/Fuser/issues/2563 + bool handle_device_dim = false; + if (query.dim() == 5) { + handle_device_dim = true; + + NVF_CHECK(key.dim() == 5 && value.dim() == 5); + + auto query_domain = + TensorDomain::noReductions(this->query()->getLogicalDomain()); + auto key_domain = + TensorDomain::noReductions(this->key()->getLogicalDomain()); + auto value_domain = + TensorDomain::noReductions(this->value()->getLogicalDomain()); + NVF_CHECK( + query_domain.front()->isDeviceDim(), + "Only support DID parallelization on outermost axis"); + NVF_CHECK( + key_domain.front()->isDeviceDim(), + "Only support DID parallelization on outermost axis"); + NVF_CHECK( + value_domain.front()->isDeviceDim(), + "Only support DID parallelization on outermost axis"); + + query = query.squeeze(0); + key = key.squeeze(0); + value = value.squeeze(0); + } + + // Flash attention requires the last dimension to be padded to 8. 
+ // https://github.com/pytorch/pytorch/blob/c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad/aten/src/ATen/native/transformers/attention.cpp#L675-L677 + const auto last_dim_size = query.size(-1); + auto pad_last_dim = [last_dim_size]( + at::Tensor inp, int alignment_size) -> at::Tensor { + if (last_dim_size % alignment_size == 0) { + return inp; + } + auto pad_count = alignment_size - (last_dim_size % alignment_size); + auto padded_inp = at::pad(inp, {0, pad_count}); + return padded_inp; + }; + + query = pad_last_dim(query, 8); + key = pad_last_dim(key, 8); + value = pad_last_dim(value, 8); + + // Compute scale using original size of last dimension + double scale = inputs.size() > 5 ? inputs.back().as() + : 1.0 / std::sqrt(last_dim_size); + + // ATen reference: + // https://github.com/pytorch/pytorch/blob/c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad/aten/src/ATen/native/transformers/attention.cpp#L680-L681 + auto + [output, + log_sumexp, + cum_seq_q, + cum_seq_k, + query_seq_len, + key_seq_len, + philox_seed, + philox_offset, + debug_attn_mask] = + at::_scaled_dot_product_flash_attention( + query, + key, + value, + dropout_p, + is_causal, + /*return_debug_mask=*/false, + scale); + + // If the inputs were padded, slice the output to restore the original + // size + if (output.size(-1) != last_dim_size) { + output = output.slice(-1, 0, last_dim_size); + } + + // Add back the device dim axis for output. + if (handle_device_dim) { + output = output.unsqueeze(0); + log_sumexp = log_sumexp.unsqueeze(0); + } + + // We ignore cum_seq_q/k outputs since they are undefined tensors for + // non-nested tensors. We do not store query/key_seq_len since they can be + // computed in non-nested tensor directly. debug_attn_mask is ignored + // since `return_debug_mask=false`. 
+ return {output, log_sumexp, philox_seed, philox_offset}; +} + +std::vector SdpaBwdOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + // Backward tensor inputs: grad_input, query, key, value, output, + // logsumexp, max_q/k. Temporary handling of DID parallelization. See + // https://github.com/NVIDIA/Fuser/issues/2563 + bool first_dim_is_did = this->key()->as()->axis(0)->isDeviceDim(); + auto out_grad = inputs[0].as(); + if (first_dim_is_did) { + NVF_CHECK(out_grad.dim() == 5, "Expected 5D but found ", out_grad.sizes()); + } else { + NVF_CHECK(out_grad.dim() == 4, "Expected 4D but found ", out_grad.sizes()); + } + + std::vector bwd_inputs; + for (auto idx : c10::irange(6)) { + auto in_tensor = inputs.at(idx).as(); + // Removing the size 1 from sharded axis from tensors. + if (first_dim_is_did) { + in_tensor = in_tensor.squeeze(0); + } + bwd_inputs.push_back(in_tensor); + } + const auto dropout_p = inputs.at(6).as(); + const auto is_causal = inputs.at(7).as(); + const auto philox_seed = inputs.at(8).as(); + const auto philox_offset = inputs.at(9).as(); + + // Flash attention requires the last dimension to be padded to 8. + // https://github.com/pytorch/pytorch/blob/c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad/aten/src/ATen/native/transformers/attention.cpp#L675-L677 + const auto last_dim_size = bwd_inputs[0].size(-1); + auto pad_last_dim = [last_dim_size]( + at::Tensor inp, int alignment_size) -> at::Tensor { + if (last_dim_size % alignment_size == 0) { + return inp; + } + auto pad_count = alignment_size - (last_dim_size % alignment_size); + auto padded_inp = at::pad(inp, {0, pad_count}); + return padded_inp; + }; + + // Compute scale using original size of last dimension + double scale = inputs.size() > 10 ? 
inputs.back().as() + : 1.0 / std::sqrt(last_dim_size); + + // ATen reference: + // https://github.com/pytorch/pytorch/blob/c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad/aten/src/ATen/native/transformers/attention.cpp#L680-L681 + // cum_seq_q/k are undefined tensors for non-nested input tensors. + auto [grad_query, grad_key, grad_value] = + at::_scaled_dot_product_flash_attention_backward( + /*grad_output=*/pad_last_dim(bwd_inputs[0], 8), + /*query=*/pad_last_dim(bwd_inputs[1], 8), + /*key=*/pad_last_dim(bwd_inputs[2], 8), + /*value=*/pad_last_dim(bwd_inputs[3], 8), + /*output=*/pad_last_dim(bwd_inputs[4], 8), + /*logsumexp=*/bwd_inputs[5], + /*cum_seq_q=*/at::Tensor(), + /*cum_seq_k=*/at::Tensor(), + // Note: ATen implementation expects max_q/max_k as scalars. + /*max_q=*/bwd_inputs[1].size(2), + /*max_k=*/bwd_inputs[2].size(2), + /*dropout_p=*/dropout_p, + /*is_causal=*/is_causal, + /*philox_seed=*/philox_seed, + /*philox_offset=*/philox_offset, + /*scale=*/scale); + + // If the inputs were padded, slice the grads to restore the original size + auto slice_last_dim = [last_dim_size](at::Tensor output) -> at::Tensor { + if (output.size(-1) == last_dim_size) { + return output; + } + return output.slice(-1, 0, last_dim_size); + }; + + // Add device dimension back to outputs. 
+ if (first_dim_is_did) { + grad_query = grad_query.unsqueeze(0); + grad_key = grad_key.unsqueeze(0); + grad_value = grad_value.unsqueeze(0); + } + + return { + slice_last_dim(grad_query), + slice_last_dim(grad_key), + slice_last_dim(grad_value)}; +} + +std::vector EmbeddingFwdOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + auto input = inputs.at(0).as(); + auto weight = inputs.at(1).as(); + auto norm_type = inputs.at(2).as(); + auto scale_grad_by_freq = inputs.at(3).as(); + auto sparse = inputs.at(4).as(); + std::optional padding_idx = std::nullopt; + if (has_padding_idx()) { + padding_idx = inputs.at(5).as(); + } + std::optional max_norm = std::nullopt; + if (has_max_norm()) { + auto idx = 5 + has_padding_idx(); + max_norm = inputs.at(idx).as(); + } + + namespace F = torch::nn::functional; + return {F::embedding( + input, + weight, + F::EmbeddingFuncOptions() + .padding_idx(padding_idx) + .max_norm(max_norm) + .norm_type(norm_type) + .scale_grad_by_freq(scale_grad_by_freq) + .sparse(sparse))}; +} + +} // namespace nvfuser diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 60c58e86115..d73c92db868 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -70,20 +70,6 @@ std::string FullOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } -std::vector FullOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - std::vector shape; - for (auto i : c10::irange(inputs.size() - 1)) { - shape.push_back(inputs.at(i).as()); - } - DataType dtype = getFillValue()->getDataType().value(); - const auto options = - at::TensorOptions().device(at::kCUDA).dtype(data_type_to_aten(dtype)); - using namespace PolymorphicValue_functions; - return {at::full(shape, toScalar(inputs.back()), options)}; -} - NVFUSER_DEFINE_CLONE_AND_CREATE(FullOp) SelectOp::SelectOp( @@ -119,15 +105,6 @@ IterDomain* SelectOp::getIndexedID() const { .at(dim()); } -std::vector 
SelectOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - const auto& in = inputs.at(0).as(); - int64_t dimension = dim(); - int64_t index = (int64_t)inputs.at(1); - return {in.select(dimension, index)}; -} - NVFUSER_DEFINE_CLONE_AND_CREATE(SelectOp) IndexSelectOp::IndexSelectOp( @@ -167,15 +144,6 @@ IterDomain* IndexSelectOp::getConsumerOfIndexedID() const { return ir_utils::getTvOutput(this)->getLogicalDomain().at(dim()); } -std::vector IndexSelectOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - const auto& in = inputs.at(0).as(); - int64_t dimension = dim(); - const auto& indices = inputs.at(1).as().squeeze(); - return {at::index_select(in, dimension, indices)}; -} - NVFUSER_DEFINE_CLONE_AND_CREATE(IndexSelectOp) TorchGatherOp::TorchGatherOp( @@ -220,19 +188,6 @@ IterDomain* TorchGatherOp::getConsumerOfIndexedID() const { return ir_utils::getTvOutput(this)->getLogicalDomain().at(dim()); } -std::vector TorchGatherOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - const auto& input = inputs.at(0).as(); - const auto& index = inputs.at(1).as(); - auto dimension = dim(); - if (exactSizes()) { - return {at::take_along_dim(input, index, dimension)}; - } else { - return {at::gather(input, dimension, index)}; - } -} - NVFUSER_DEFINE_CLONE_AND_CREATE(TorchGatherOp) ScatterOp::ScatterOp( @@ -271,16 +226,6 @@ IterDomain* ScatterOp::getIndexedID() const { return ir_utils::getTvOutput(this)->getLogicalDomain().at(dim()); } -std::vector ScatterOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - const auto& input = inputs.at(0).as(); - const auto& index = inputs.at(1).as(); - const auto& src = inputs.at(2).as(); - auto dimension = dim(); - return {at::scatter(input, dimension, index, src)}; -} - NVFUSER_DEFINE_CLONE_AND_CREATE(ScatterOp) IotaOp::IotaOp( @@ -314,31 +259,6 @@ std::string IotaOp::toInlineString(int indent_size) const { 
NVF_CHECK(false, "Tensor op can not be printed inline"); } -std::vector IotaOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - const auto options = - at::TensorOptions().device(at::kCUDA).dtype(data_type_to_aten(dtype())); - int64_t length = (int64_t)inputs.at(0); - - if (isIntegralType(dtype())) { - int64_t start = (int64_t)inputs.at(1); - int64_t step = (int64_t)inputs.at(2); - int64_t end = start + step * length; - return {at::arange(start, end, step, options)}; - } else if (isFloatingPointType(dtype())) { - double start = (double)inputs.at(1); - double step = (double)inputs.at(2); - // Due to rounding error, it can be hard to guarantee the size of - // the output of arange to be exactly length, so we generate a - // larger tensor and truncate it to length. - double end = start + step * ((double)length + 1); - return {at::arange(start, end, step, options).narrow(0, 0, length)}; - } else { - NVF_THROW("Unsupported dtype in IotaOp evaluator: ", dtype()); - } -} - NVFUSER_DEFINE_CLONE_AND_CREATE(IotaOp) EyeOp::EyeOp(IrBuilderPasskey passkey, Val* out, DataType dtype) @@ -366,19 +286,6 @@ std::string EyeOp::toString(int indent_size) const { std::string EyeOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } -std::vector EyeOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - const auto options = - at::TensorOptions().device(at::kCUDA).dtype(data_type_to_aten(dtype())); - int64_t nrows = (int64_t)inputs.at(0); - if (inputs.size() > 1) { - int64_t ncols = (int64_t)inputs.at(1); - return {at::eye(nrows, ncols, options)}; - } else { - return {at::eye(nrows, options)}; - } -} NVFUSER_DEFINE_CLONE_AND_CREATE(EyeOp) @@ -389,133 +296,6 @@ UnaryOp::UnaryOp(IrBuilderPasskey passkey, UnaryOpType type, Val* out, Val* in) addDataAttribute(type); } -std::vector UnaryOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - using 
namespace PolymorphicValue_functions; - - const auto& in = inputs.at(0); - if (!in.hasValue()) { - return {std::monostate{}}; - } - - switch (getUnaryOpType()) { - case UnaryOpType::Neg: - return {-in}; - case UnaryOpType::Cast: - if (in.is()) { - return {PolymorphicValue( - in.as().to(data_type_to_aten(out()->dtype())))}; - } else if (isIntegralType(*out()->getDataType())) { - return {PolymorphicValue((int64_t)in)}; - } else if (isFloatingPointType(*out()->getDataType())) { - return {PolymorphicValue((double)in)}; - } else if (out()->getDataType() == DataType::Bool) { - return {PolymorphicValue((bool)in)}; - } else if (isComplexType(*out()->getDataType())) { - return {PolymorphicValue((std::complex)in)}; - } else { - NVF_THROW("dtype not supported in evaluator: ", *out()->getDataType()); - } - case UnaryOpType::Reciprocal: - return {1.0 / in}; - break; - case UnaryOpType::Abs: - return {abs(in)}; - break; - case UnaryOpType::LogicalNot: - return {!in}; - break; - case UnaryOpType::BitwiseNot: - return {~in}; - break; - case UnaryOpType::Erf: - return {erf(in)}; - break; - case UnaryOpType::ToUnsignedSmemAddr: - return {(int64_t)(unsigned)in}; - break; - case UnaryOpType::AdjustPartialLdMatrixAddrInTuring8: - case UnaryOpType::AdjustPartialLdMatrixAddrInTuring16: - return {in}; - break; - case UnaryOpType::Dereference: - if (*out()->getDataType() == DataType::Float) { - return {PolymorphicValue((double)*(float*)in)}; - } else { - NVF_THROW("dtype not supported in evaluator: ", *out()->getDataType()); - } - break; - case UnaryOpType::Sigmoid: - return {in.as().sigmoid()}; - break; - case UnaryOpType::Tanh: - return {in.as().tanh()}; - break; - case UnaryOpType::Relu: - return {at::relu(in.as())}; - break; - case UnaryOpType::Gelu: - return {at::gelu(in.as())}; - break; - case UnaryOpType::Exp: - return {at::exp(in.as())}; - break; - case UnaryOpType::Sin: - return {in.as().sin()}; - break; - case UnaryOpType::Signbit: - return {signbit(in)}; - break; - case 
UnaryOpType::Cos: - return {in.as().cos()}; - break; - case UnaryOpType::BitCast: - NVF_CHECK( - dataTypeSize(input(0)->dtype()) == dataTypeSize(out()->dtype()), - "BitCast only works for types of the same size"); - if (isComplexType(input(0)->dtype()) && - std::holds_alternative(out()->dtype().type)) { - // view_as_real case. - auto vec_type = std::get(out()->dtype().type); - auto inp_scalar_type = getTypeFromComplexType(input(0)->dtype()); - NVF_CHECK( - *vec_type.type == inp_scalar_type, - "Output type must be the same as the scalar type of the complex input."); - NVF_CHECK( - vec_type.size == 2, - "Expected output to be array of size 2, found array of size ", - vec_type.size); - return {in.as()}; - } else { - return {in.as().view(data_type_to_aten(out()->dtype()))}; - } - break; - case UnaryOpType::Rsqrt: - return {in.as().rsqrt()}; - break; - case UnaryOpType::Real: - return {at::real(in.as())}; - break; - case UnaryOpType::Imag: - return {at::imag(in.as())}; - break; - case UnaryOpType::Tan: - return {in.as().tan()}; - break; - case UnaryOpType::IsFinite: - return {at::isfinite(in.as())}; - break; - default: - NVF_CHECK( - false, - "Unexpected operator type ", - getUnaryOpType(), - " in ", - toString()); - } -} - void UnaryOp::printHelper(std::stringstream& ss, std::string input) const { auto op_type = getUnaryOpType(); @@ -581,102 +361,6 @@ BinaryOp::BinaryOp( addDataAttribute(type); } -std::vector BinaryOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - using namespace PolymorphicValue_functions; - const auto& lhs = inputs.at(0); - const auto& rhs = inputs.at(1); - - switch (getBinaryOpType()) { - case BinaryOpType::Add: - return {lhs + rhs}; - break; - case BinaryOpType::Sub: - return {lhs - rhs}; - break; - case BinaryOpType::Mul: - return {lhs * rhs}; - break; - case BinaryOpType::Div: - return {lhs / rhs}; - break; - case BinaryOpType::Mod: - NVF_CHECK(rhs != 0); - return {lhs % rhs}; - break; - case 
BinaryOpType::Fmod: - NVF_CHECK(rhs != 0); - return {fmod(lhs, rhs)}; - break; - case BinaryOpType::CeilDiv: - NVF_CHECK(rhs != 0); - return {ceildiv(lhs, rhs)}; - break; - case BinaryOpType::LogicalAnd: - return {lhs && rhs}; - break; - case BinaryOpType::LogicalOr: - return {lhs || rhs}; - break; - case BinaryOpType::BitwiseAnd: - return {lhs & rhs}; - break; - case BinaryOpType::BitwiseOr: - return {lhs | rhs}; - break; - case BinaryOpType::BitwiseXor: - return {lhs ^ rhs}; - break; - case BinaryOpType::Eq: - return {eq(lhs, rhs)}; - break; - case BinaryOpType::NE: - return {ne(lhs, rhs)}; - break; - case BinaryOpType::GT: - return {gt(lhs, rhs)}; - break; - case BinaryOpType::GE: - return {ge(lhs, rhs)}; - break; - case BinaryOpType::LT: - return {lt(lhs, rhs)}; - break; - case BinaryOpType::LE: - return {le(lhs, rhs)}; - break; - case BinaryOpType::Max: - return {max(lhs, rhs)}; - break; - case BinaryOpType::Min: - return {min(lhs, rhs)}; - break; - case BinaryOpType::Gcd: - return {gcd(lhs, rhs)}; - break; - case BinaryOpType::Lshift: - return {lhs << rhs}; - break; - case BinaryOpType::Rshift: - return {lhs >> rhs}; - break; - case BinaryOpType::Complex: - return {at::complex(lhs.as(), rhs.as())}; - break; - case BinaryOpType::Pow: - return {pow(lhs, rhs)}; - break; - default: - NVF_CHECK( - false, - "Unexpected operator type: ", - getBinaryOpType(), - " in ", - toString()); - } -} - void BinaryOp::printHelper( std::stringstream& ss, int indent_size, @@ -756,38 +440,6 @@ TernaryOp::TernaryOp( addDataAttribute(type); } -std::vector TernaryOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - using namespace PolymorphicValue_functions; - const auto& a = inputs.at(0); - const auto& b = inputs.at(1); - const auto& c = inputs.at(2); - switch (getTernaryOpType()) { - case TernaryOpType::Clamp: - return {std::min(std::max(a, b), c)}; - break; - case TernaryOpType::Lerp: - // This is the same lerp computed in helpers.cu - // 
https://math.stackexchange.com/a/1798323 - return {(c < 0.5) ? a + c * (b - a) : b - (b - a) * (1.0 - c)}; - break; - case TernaryOpType::Threshold: - return {(a <= b) ? c : a}; - break; - case TernaryOpType::Where: - return {a.as() ? b : c}; - break; - default: - NVF_CHECK( - false, - "Unexpected operator type: ", - getTernaryOpType(), - " in ", - toString()); - } -} - void TernaryOp::printHelper( std::stringstream& ss, int indent_size, @@ -888,12 +540,6 @@ std::string ArrayConstruct::toInlineString(int indent_size) const { return ss.str(); } -std::vector ArrayConstruct::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - return {PolymorphicValue(inputs)}; -} - NVFUSER_DEFINE_CLONE_AND_CREATE(ArrayConstruct) ReverseArray::ReverseArray(IrBuilderPasskey passkey, Val* output, Val* input) @@ -935,16 +581,6 @@ std::string ReverseArray::toInlineString(int indent_size) const { return ss.str(); } -std::vector ReverseArray::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - NVF_ERROR(inputs.size() == 1, "ReverseArray expects 1 input"); - PolymorphicValue array = inputs.at(0); - auto& vec = array.as(); - std::reverse(vec.begin(), vec.end()); - return {std::move(array)}; -} - NVFUSER_DEFINE_CLONE_AND_CREATE(ReverseArray) GetItem::GetItem(IrBuilderPasskey passkey, Val* output, Val* array, Val* index) @@ -971,13 +607,6 @@ std::string GetItem::toInlineString(int indent_size) const { return ss.str(); } -std::vector GetItem::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - NVF_ERROR(inputs.size() == 2, "GetItem expects 2 inputs"); - return {PolymorphicValue(inputs.at(0)[inputs.at(1)])}; -} - NVFUSER_DEFINE_CLONE_AND_CREATE(GetItem) StructConstruct::StructConstruct( @@ -1033,22 +662,6 @@ std::string StructConstruct::toInlineString(int indent_size) const { return ss.str(); } -std::vector StructConstruct::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - 
NVF_ERROR( - this->inputs().size() == inputs.size(), - "StructConstruct expects ", - this->inputs().size(), - " inputs"); - PolymorphicValue struct_ = - std::get(output(0)->dtype().type).create(); - for (int64_t i : c10::irange((int64_t)inputs.size())) { - struct_->*attribute(i) = inputs.at(i); - } - return {std::move(struct_)}; -} - NVFUSER_DEFINE_CLONE_AND_CREATE(StructConstruct) GetAttr::GetAttr( @@ -1079,13 +692,6 @@ std::string GetAttr::toInlineString(int indent_size) const { return ss.str(); } -std::vector GetAttr::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - NVF_ERROR(inputs.size() == 1, "GetAttr expects 1 input"); - return {inputs.at(0)->*attr()}; -} - NVFUSER_DEFINE_CLONE_AND_CREATE(GetAttr) GetMetaData::GetMetaData(IrBuilderPasskey passkey, Val* output, Val* input) @@ -1132,14 +738,6 @@ std::string TensorConstruct::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } -std::vector TensorConstruct::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - NVF_ERROR(inputs.size() == 1, "TensorConstruct expects 1 input"); - using namespace PolymorphicValue_functions; - return {toTensor(inputs.at(0))}; -} - NVFUSER_DEFINE_CLONE_AND_CREATE(TensorConstruct) RNGOp::RNGOp( @@ -1289,26 +887,6 @@ std::string BroadcastOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } -std::vector BroadcastOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - NVF_ERROR( - inputs.size() == 1, - "BroadcastOp expects exactly 1 input, but received ", - inputs.size()); - std::vector out_shape; - const auto& in = inputs.at(0).as(); - int64_t idx = 0; - for (bool b : getBroadcastDimFlags()) { - if (b) { - out_shape.push_back(1); - } else { - out_shape.push_back(in.sizes()[idx++]); - } - } - return {in.view(out_shape)}; -} - NVFUSER_DEFINE_CLONE_AND_CREATE(BroadcastOp) SqueezeOp::SqueezeOp( @@ -1385,35 
+963,6 @@ std::string SqueezeOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } -std::vector SqueezeOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - NVF_ERROR( - inputs.size() == 1, - "SqueezeOp expects exactly 1 input, but received ", - inputs.size()); - std::vector out_shape; - const auto& in = inputs.at(0).as(); - const auto& is_squeeze_dims = getSqueezeDimFlags(); - NVF_ERROR( - (int64_t)is_squeeze_dims.size() == in.dim(), - "The dimensions of input tensor and does not match with is_squeeze_dims"); - at::Tensor out = in; - for (int64_t i : c10::irange((int64_t)is_squeeze_dims.size())) { - if (is_squeeze_dims[i]) { - if (in.stride(i) == 0) { - // If the input dimension is expanded in this dimension, undo the expand - // by slicing. This ensures that any broadcast dimensions will be - // unexpanded when we do the final call to view() - out = out.slice(i, 0, 1); - } - } else { - out_shape.push_back(in.sizes()[i]); - } - } - return {out.view(out_shape)}; -} - void SqueezeOp::checkConcretization(Val* old_val, Val* new_val) const { Expr::checkConcretization(old_val, new_val); // does nullptr, vtype checks NVF_CHECK( @@ -1508,43 +1057,6 @@ std::string ReductionOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } -std::vector ReductionOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - const auto& input = inputs.at(0).as(); - const auto output = out()->as(); - - NVF_ERROR( - !output->hasRoot(), - "Evaluation for rFactored reductions is not supported."); - - std::vector reduction_axes; - for (const auto i : c10::irange(int64_t(output->getLogicalDomain().size()))) { - auto ax = output->getLogicalDomain().at(i); - if (ax->isReduction()) { - reduction_axes.push_back(i); - } - } - switch (getReductionOpType()) { - case BinaryOpType::Add: - return {at::sum(input, reduction_axes)}; - break; - case 
BinaryOpType::Max: - return {at::amax(input, reduction_axes)}; - break; - case BinaryOpType::Min: - return {at::amin(input, reduction_axes)}; - break; - default: - NVF_CHECK( - false, - "Unexpected operator type: ", - getReductionOpType(), - " in ", - toString()); - } -} - NVFUSER_DEFINE_CLONE_AND_CREATE(ReductionOp) GroupedReductionOp::GroupedReductionOp( @@ -1599,46 +1111,6 @@ int GroupedReductionOp::getExprIndexOfOutput(Val* output_val) const { NVF_THROW("Not an output, ", output_val->toString(), ", of ", toString()); } -std::vector GroupedReductionOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - const auto num_reductions = numHorizontallyGroupedExprs(); - std::vector grouped_reduction_out; - grouped_reduction_out.reserve(num_reductions); - for (const auto i : c10::irange(num_reductions)) { - const auto& in_tensor = inputs.at(i).as(); - const auto out_tv = output(i)->as(); - NVF_ERROR( - !out_tv->hasRoot(), - "Evaluation for rFactored reductions is not supported."); - - std::vector reduction_axes; - for (const auto id : - c10::irange(int64_t(out_tv->getLogicalDomain().size()))) { - auto ax = out_tv->getLogicalDomain().at(id); - if (ax->isReduction()) { - reduction_axes.push_back(id); - } - } - switch (getReductionOpType(i)) { - case BinaryOpType::Add: - grouped_reduction_out.emplace_back(at::sum(in_tensor, reduction_axes)); - break; - case BinaryOpType::Max: - grouped_reduction_out.emplace_back(at::amax(in_tensor, reduction_axes)); - break; - default: - NVF_CHECK( - false, - "Unexpected operator type: ", - getReductionOpType(i), - " in ", - toString()); - } - } - return grouped_reduction_out; -} - NVFUSER_DEFINE_CLONE_AND_CREATE(GroupedReductionOp) std::optional WelfordTriplet::getNameOf( @@ -1820,32 +1292,6 @@ std::string WelfordOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } -std::vector WelfordOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) 
const { - NVF_ERROR( - !hasInit(), - "Evaluation for WelfordOp is not implemented for non-empty initial values."); - const auto& in_tensor = inputs.at(0).as(); - const auto out_tv = out()->as(); - NVF_ERROR( - !out_tv->hasRoot(), - "Evaluation for WelfordOp is not supported when output is rFactored."); - - int64_t N = 1; - std::vector reduction_axes; - for (const auto i : c10::irange(int64_t(out_tv->getLogicalDomain().size()))) { - auto ax = out_tv->getLogicalDomain().at(i); - if (ax->isReduction()) { - reduction_axes.push_back(i); - N *= in_tensor.size(i); - } - } - const auto [in_var, in_avg] = - at::var_mean(in_tensor, reduction_axes, false, false); - return {in_avg, in_var * N, N}; -} - NVFUSER_DEFINE_CLONE_AND_CREATE(WelfordOp) GroupedWelfordOp::GroupedWelfordOp( @@ -2125,17 +1571,6 @@ std::string ExpandOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } -std::vector ExpandOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - const auto& in = inputs.at(0).as(); - std::vector expanded_size; - for (auto i : c10::irange(1, inputs.size())) { - expanded_size.push_back((int64_t)inputs.at(i)); - } - return {in.expand(expanded_size)}; -} - NVFUSER_DEFINE_CLONE_AND_CREATE(ExpandOp) RepeatOp::RepeatOp(IrBuilderPasskey passkey, TensorView* out, TensorView* in) @@ -2185,37 +1620,6 @@ std::string RepeatOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } -std::vector RepeatOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - NVF_ERROR( - inputs.size() == 1, - "RepeatOp expects exactly 1 input, but received ", - inputs.size()); - auto tensor = inputs.at(0).as(); - std::vector multipliers; - multipliers.reserve(out()->getLogicalDomain().size()); - const auto c2p = - PairwiseLogicalDomainMap(in(), out()).mapConsumerToProducer(); - for (const auto i : c10::irange(out()->getLogicalDomain().size())) { - auto out_id 
= out()->getLogicalDomain().at(i); - auto inp_id = c2p.at(out_id); - auto out_extent = ee.evaluate(out_id->extent()).as(); - auto inp_extent = ee.evaluate(inp_id->extent()).as(); - NVF_ERROR( - out_extent % inp_extent == 0, - "For dimension ", - i, - ", the output extent (", - out_extent, - " should be a multiple of the input extent (", - inp_extent, - ")."); - multipliers.push_back(out_extent / inp_extent); - } - return {tensor.repeat(multipliers)}; -} - NVFUSER_DEFINE_CLONE_AND_CREATE(RepeatOp) ViewAsScalar::ViewAsScalar( @@ -2241,13 +1645,6 @@ std::string ViewAsScalar::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } -std::vector ViewAsScalar::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - const at::Tensor& in = inputs.at(0).as(); - return {at::view_as_real(in)}; -} - NVFUSER_DEFINE_CLONE_AND_CREATE(ViewAsScalar) ViewOp::ViewOp(IrBuilderPasskey passkey, Val* out, Val* in) : Expr(passkey) { @@ -2274,33 +1671,6 @@ std::string ViewOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } -std::vector ViewOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - NVF_ERROR(inputs.size() == 1); - const at::Tensor& in_tensor = inputs[0].as(); - - const std::vector& out_logical = out()->getLogicalDomain(); - std::vector out_shape; - out_shape.reserve(out_logical.size()); - for (IterDomain* id : out_logical) { - if (id->isDeviceDim()) { - out_shape.push_back(1); - } else { - out_shape.push_back( - ee.evaluate(id->getMaybeExpandedExtent()).as()); - } - } - - // TODO: check allocation domain and contiguity. - - // Use `at::Tensor::reshape` instead of `at::Tensor::view` because `ViewOp` - // doesn't always produce an alias. For example, when merging an expanded - // `IterType::Broadcast` and an `IterType::Iteration`, `ViewOp` has to realize - // the expand. 
- return {in_tensor.reshape(out_shape)}; -} - NVFUSER_DEFINE_CLONE_AND_CREATE(ViewOp) LoadStoreOp::LoadStoreOp( @@ -2334,27 +1704,6 @@ LoadStoreOp::LoadStoreOp( addDataAttribute(cache_op); } -std::vector LoadStoreOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - if (TensorView* out_tv = dynamic_cast(out())) { - if (out_tv->hasRoot()) { - std::optional> permutation = - ir_utils::computePermutation( - out_tv->getRootDomain(), out_tv->getLogicalDomain()); - NVF_ERROR( - permutation.has_value(), - "The logical domain of a Set.Permute is supposed to be a permutation of the root domain: ", - out_tv->toString()); - NVF_ERROR(inputs.size() == 1); - at::Tensor in_tensor = inputs[0].as(); - at::Tensor out_tensor = in_tensor.permute(*permutation); - return {out_tensor}; - } - } - return inputs; -} - std::string LoadStoreOp::toString(int indent_size) const { std::stringstream ss; std::string optype = load_store_type2string(opType()); @@ -4283,36 +3632,6 @@ std::pair PadOp::getPadWidths(int64_t axis) const { (*(getPadWidthInputBegin() + offset_odd))->as()); } -std::vector PadOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - const auto& in = inputs.at(0).as(); - - std::vector pad_widths; - auto pad_width_offset = getPadWidthInputOffset(); - auto num_dims = in.dim(); - - for (auto i = num_dims - 1; i > -1; i--) { - auto left_pad = (int64_t)inputs.at(pad_width_offset + 2 * i); - auto right_pad = (int64_t)inputs.at(pad_width_offset + 2 * i + 1); - pad_widths.push_back(left_pad); - pad_widths.push_back(right_pad); - } - - if (isComplexType(*out()->getDataType())) { - std::complex value = - static_cast>(inputs.at(1)); - auto real = at::real(in); - auto imag = at::imag(in); - auto padded_real = at::pad(real, pad_widths, "constant", value.real()); - auto padded_imag = at::pad(imag, pad_widths, "constant", value.imag()); - return {at::complex(padded_real, padded_imag)}; - } else { - double value = 
static_cast(inputs.at(1)); - return {at::pad(in, pad_widths, "constant", value)}; - } -} - SliceOp::SliceOp( IrBuilderPasskey passkey, TensorView* out, @@ -4381,22 +3700,6 @@ std::vector SliceOp::getRanges() const { return ranges; } -std::vector SliceOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - const auto& in = inputs.at(0).as(); - std::vector ranges; - auto ranges_offset = getRangeInputOffset(); - auto num_dims = in.dim(); - for (const auto i : c10::irange(num_dims)) { - auto start = (int64_t)inputs.at(ranges_offset + 3 * i); - auto stop = (int64_t)inputs.at(ranges_offset + 3 * i + 1); - auto step = (int64_t)inputs.at(ranges_offset + 3 * i + 2); - ranges.emplace_back(at::indexing::Slice(start, stop, step)); - } - return {in.index(ranges)}; -} - CatOp::CatOp( IrBuilderPasskey passkey, Val* out, @@ -4493,24 +3796,6 @@ Val* CatOp::getPred(int input_idx) const { return pred; } -std::vector CatOp::evaluate( - const ExpressionEvaluator& ee, - std::unordered_map& known_values) const { - // CatOp is preceded by a PadOp internally. - // For ATen evaluation, directly compute the unpadded inputs. 
- std::vector unpadded_inputs; - unpadded_inputs.reserve(inputs().size()); - int64_t concat_dim = concatenatedDim(); - for (Val* inp : inputs()) { - NVF_CHECK( - inp->definition() != nullptr && inp->definition()->isA(), - "Expected CatOp to be preceded by a PadOp."); - auto eval_i = ee.evaluate(inp->definition()->input(0), known_values); - unpadded_inputs.push_back(eval_i.as()); - } - return {at::cat(unpadded_inputs, concat_dim)}; -} - MatmulOp::MatmulOp(IrBuilderPasskey passkey, Val* out, Val* in_a, Val* in_b) : Expr(passkey) { addOutput(out); @@ -4532,51 +3817,6 @@ std::string MatmulOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } -std::vector MatmulOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - const auto a = inputs.at(0).as(); - const auto b = inputs.at(1).as(); - - auto matmul_out = at::matmul(a, b); - - // When the contracting dimension is sharded, each device has a partial - // matmul output and is followed by an allreduce. For loop split, this is - // represented as an rfactored reduction. The local matmul logical domain - // after the rfactor is: i{DIDx}, i{M}, i{N}, r{K//d}. Unsqueeze the - // rfactored DID axis to correctly bind with the logical domain. 
See - // tests/python/test_multidevice.py/test_matmul_allreduce_loop_split - auto out_logical = TensorDomain::noReductions(out()->getLogicalDomain()); - int64_t rfactor_did_idx = -1; - for (auto idx : c10::irange(static_cast(out_logical.size()))) { - if (!out_logical.at(idx)->isRFactorProduct() || - !out_logical.at(idx)->isDeviceDim()) { - continue; - } - if (rfactor_did_idx != -1) { - NVF_THROW( - "Expected only 1 rfactored DID iterdomain, found at least 2 in ", - out_logical); - } - rfactor_did_idx = idx; - } - - if (rfactor_did_idx != -1) { - matmul_out = matmul_out.unsqueeze(rfactor_did_idx); - } - - const auto& [sizes, strides] = inferShapeOfOutput(out(), ee); - auto meta_out = at::detail::empty_strided_meta(sizes, strides, a.dtype()); - - if (meta_out.is_contiguous()) { - return {matmul_out}; - } - - auto strided_matmul_out = at::empty_strided(sizes, strides, a.options()); - strided_matmul_out = strided_matmul_out.copy_(matmul_out); - return {strided_matmul_out}; -} - LinearOp::LinearOp( IrBuilderPasskey passkey, Val* out, @@ -4611,47 +3851,6 @@ std::string LinearOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } -std::vector LinearOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - const auto in = inputs.at(0).as(); - auto weight = inputs.at(1).as(); - - auto squeeze_device_dims = [](at::Tensor& t, - int64_t num_device_dims) -> void { - // Record the initial shape for the error message. - std::vector shape = t.sizes().vec(); - for ([[maybe_unused]] auto _ : c10::irange(num_device_dims)) { - NVF_CHECK( - t.size(0) == 1, - "When the weight is >2D, expect its preceding dimensions and " - "the bias's preceding dimensions to " - "be DID-parallel and therefore size-1: ", - shape); - t = t.squeeze(0); - } - }; - - // The squeezes and unsqueezes are currently required to support a sharded - // linear layer. Remove them after #2563. 
- auto num_device_dims = weight.dim() - 2; - squeeze_device_dims(weight, num_device_dims); - - at::Tensor out; - if (has_bias()) { - auto bias = inputs.at(2).as(); - squeeze_device_dims(bias, num_device_dims); - out = at::linear(in, weight, bias); - } else { - out = at::linear(in, weight); - } - - for ([[maybe_unused]] auto _ : c10::irange(num_device_dims)) { - out = out.unsqueeze(0); - } - return {out}; -} - SdpaFwdOp::SdpaFwdOp( IrBuilderPasskey passkey, TensorView* output, @@ -4707,106 +3906,6 @@ std::string SdpaFwdOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } -std::vector SdpaFwdOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - auto query = inputs.at(0).as(); - auto key = inputs.at(1).as(); - auto value = inputs.at(2).as(); - - const auto dropout_p = inputs.at(3).as(); - const auto is_causal = inputs.at(4).as(); - - // Temporary handling of DID parallelization see - // https://github.com/NVIDIA/Fuser/issues/2563 - bool handle_device_dim = false; - if (query.dim() == 5) { - handle_device_dim = true; - - NVF_CHECK(key.dim() == 5 && value.dim() == 5); - - auto query_domain = - TensorDomain::noReductions(this->query()->getLogicalDomain()); - auto key_domain = - TensorDomain::noReductions(this->key()->getLogicalDomain()); - auto value_domain = - TensorDomain::noReductions(this->value()->getLogicalDomain()); - NVF_CHECK( - query_domain.front()->isDeviceDim(), - "Only support DID parallelization on outermost axis"); - NVF_CHECK( - key_domain.front()->isDeviceDim(), - "Only support DID parallelization on outermost axis"); - NVF_CHECK( - value_domain.front()->isDeviceDim(), - "Only support DID parallelization on outermost axis"); - - query = query.squeeze(0); - key = key.squeeze(0); - value = value.squeeze(0); - } - - // Flash attention requires the last dimension to be padded to 8. 
- // https://github.com/pytorch/pytorch/blob/c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad/aten/src/ATen/native/transformers/attention.cpp#L675-L677 - const auto last_dim_size = query.size(-1); - auto pad_last_dim = [last_dim_size]( - at::Tensor inp, int alignment_size) -> at::Tensor { - if (last_dim_size % alignment_size == 0) { - return inp; - } - auto pad_count = alignment_size - (last_dim_size % alignment_size); - auto padded_inp = at::pad(inp, {0, pad_count}); - return padded_inp; - }; - - query = pad_last_dim(query, 8); - key = pad_last_dim(key, 8); - value = pad_last_dim(value, 8); - - // Conmpute scale using original size of last dimension - double scale = inputs.size() > 5 ? inputs.back().as() - : 1.0 / std::sqrt(last_dim_size); - - // ATen reference: - // https://github.com/pytorch/pytorch/blob/c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad/aten/src/ATen/native/transformers/attention.cpp#L680-L681 - auto - [output, - log_sumexp, - cum_seq_q, - cum_seq_k, - query_seq_len, - key_seq_len, - philox_seed, - philox_offset, - debug_attn_mask] = - at::_scaled_dot_product_flash_attention( - query, - key, - value, - dropout_p, - is_causal, - /*return_debug_mask=*/false, - scale); - - // If the inputs were padded, slice the output to restore the original - // size - if (output.size(-1) != last_dim_size) { - output = output.slice(-1, 0, last_dim_size); - } - - // Add back the device dim axis for output. - if (handle_device_dim) { - output = output.unsqueeze(0); - log_sumexp = log_sumexp.unsqueeze(0); - } - - // We ignore cum_seq_q/k outputs since they are undefined tensors for - // non-nested tensors. We do not store query/key_seq_len since they can be - // computed in non-nested tensor directly. debug_attn_mask is ignored - // since `return_debug_mask=false`. 
- return {output, log_sumexp, philox_seed, philox_offset}; -} - std::string Scope::toString(int indent_size) const { std::stringstream ss; for (auto expr : exprs()) { @@ -5246,94 +4345,6 @@ std::string SdpaBwdOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } -std::vector SdpaBwdOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - // Backward tensor inputs: grad_input, query, key, value, output, - // logsumexp, max_q/k Temporary handling of DID parallelization. See - // https://github.com/NVIDIA/Fuser/issues/2563 - bool first_dim_is_did = this->key()->as()->axis(0)->isDeviceDim(); - auto out_grad = inputs[0].as(); - if (first_dim_is_did) { - NVF_CHECK(out_grad.dim() == 5, "Expected 5D but found ", out_grad.sizes()); - } else { - NVF_CHECK(out_grad.dim() == 4, "Expected 4D but found ", out_grad.sizes()); - } - - std::vector bwd_inputs; - for (auto idx : c10::irange(6)) { - auto in_tensor = inputs.at(idx).as(); - // Removing the size 1 from sharded axis from tensors. - if (first_dim_is_did) { - in_tensor = in_tensor.squeeze(0); - } - bwd_inputs.push_back(in_tensor); - } - const auto dropout_p = inputs.at(6).as(); - const auto is_causal = inputs.at(7).as(); - const auto philox_seed = inputs.at(8).as(); - const auto philox_offset = inputs.at(9).as(); - - // Flash attention requires the last dimension to be padded to 8. 
- // https://github.com/pytorch/pytorch/blob/c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad/aten/src/ATen/native/transformers/attention.cpp#L675-L677 - const auto last_dim_size = bwd_inputs[0].size(-1); - auto pad_last_dim = [last_dim_size]( - at::Tensor inp, int alignment_size) -> at::Tensor { - if (last_dim_size % alignment_size == 0) { - return inp; - } - auto pad_count = alignment_size - (last_dim_size % alignment_size); - auto padded_inp = at::pad(inp, {0, pad_count}); - return padded_inp; - }; - - // Conmpute scale using original size of last dimension - double scale = inputs.size() > 10 ? inputs.back().as() - : 1.0 / std::sqrt(last_dim_size); - - // ATen reference: - // https://github.com/pytorch/pytorch/blob/c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad/aten/src/ATen/native/transformers/attention.cpp#L680-L681 - // cum_seq_q/k are undefined tensors for non-nested input tensors. - auto [grad_query, grad_key, grad_value] = - at::_scaled_dot_product_flash_attention_backward( - /*grad_output=*/pad_last_dim(bwd_inputs[0], 8), - /*query=*/pad_last_dim(bwd_inputs[1], 8), - /*key=*/pad_last_dim(bwd_inputs[2], 8), - /*value=*/pad_last_dim(bwd_inputs[3], 8), - /*output=*/pad_last_dim(bwd_inputs[4], 8), - /*logsumexp=*/bwd_inputs[5], - /*cum_seq_q=*/at::Tensor(), - /*cum_seq_k=*/at::Tensor(), - // Note: ATen implementation expects max_q/max_k as scalars. - /*max_q=*/bwd_inputs[1].size(2), - /*max_k=*/bwd_inputs[2].size(2), - /*dropout_p=*/dropout_p, - /*is_causal=*/is_causal, - /*philox_seed=*/philox_seed, - /*philox_offset=*/philox_offset, - /*scale=*/scale); - - // If the inputs were padded, slice the gradsto restore the original size - auto slice_last_dim = [last_dim_size](at::Tensor output) -> at::Tensor { - if (output.size(-1) != last_dim_size) { - return output; - } - return output.slice(-1, 0, last_dim_size); - }; - - // Add device dimension back to outputs. 
- if (first_dim_is_did) { - grad_query = grad_query.unsqueeze(0); - grad_key = grad_key.unsqueeze(0); - grad_value = grad_value.unsqueeze(0); - } - - return { - slice_last_dim(grad_query), - slice_last_dim(grad_key), - slice_last_dim(grad_value)}; -} - EmbeddingFwdOp::EmbeddingFwdOp( IrBuilderPasskey passkey, TensorView* output, @@ -5395,33 +4406,4 @@ std::string EmbeddingFwdOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } -std::vector EmbeddingFwdOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - auto input = inputs.at(0).as(); - auto weight = inputs.at(1).as(); - auto norm_type = inputs.at(2).as(); - auto scale_grad_by_freq = inputs.at(3).as(); - auto sparse = inputs.at(4).as(); - std::optional padding_idx = std::nullopt; - if (has_padding_idx()) { - padding_idx = inputs.at(5).as(); - } - std::optional max_norm = std::nullopt; - if (has_max_norm()) { - auto idx = 5 + has_padding_idx(); - max_norm = inputs.at(idx).as(); - } - - namespace F = torch::nn::functional; - return {F::embedding( - input, - weight, - F::EmbeddingFuncOptions() - .padding_idx(padding_idx) - .max_norm(max_norm) - .norm_type(norm_type) - .scale_grad_by_freq(scale_grad_by_freq) - .sparse(sparse))}; -} } // namespace nvfuser From cfce8de3e09cd62a87de7544d5c191405d2185fd Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Wed, 29 Jan 2025 13:37:23 -0800 Subject: [PATCH 02/16] Missed a line change. --- csrc/ir/nodes.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index d73c92db868..74ca3977744 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include From 5cfad143d80ed55ab5d7c82b33fe4f59e586fd3d Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Wed, 29 Jan 2025 13:48:23 -0800 Subject: [PATCH 03/16] Forgot license. 
--- csrc/ir/evaluate.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/csrc/ir/evaluate.cpp b/csrc/ir/evaluate.cpp index 3d6f7ca63a6..df1abba965a 100644 --- a/csrc/ir/evaluate.cpp +++ b/csrc/ir/evaluate.cpp @@ -1,3 +1,11 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + #include #include #include From 1fdbb400a7f906a79aa4f1bedbb1b6ea049fc67e Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Wed, 29 Jan 2025 13:49:19 -0800 Subject: [PATCH 04/16] Move expression evaluator executor to its own file. --- CMakeLists.txt | 1 + csrc/runtime/executor.cpp | 66 ---------------------- csrc/runtime/executor.h | 29 ---------- csrc/runtime/executor_dispatch.cpp | 1 + csrc/runtime/expr_eval_exec.cpp | 88 ++++++++++++++++++++++++++++++ csrc/runtime/expr_eval_exec.h | 42 ++++++++++++++ tests/cpp/test_alias.cpp | 1 + 7 files changed, 133 insertions(+), 95 deletions(-) create mode 100644 csrc/runtime/expr_eval_exec.cpp create mode 100644 csrc/runtime/expr_eval_exec.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 5e0a6300370..67f252f1a96 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -210,6 +210,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/runtime/executor_kernel_arg.cpp ${NVFUSER_SRCS_DIR}/runtime/executor_params.cpp ${NVFUSER_SRCS_DIR}/runtime/executor_utils.cpp + ${NVFUSER_SRCS_DIR}/runtime/expr_eval_exec.cpp ${NVFUSER_SRCS_DIR}/runtime/fusion_cache_utils.cpp ${NVFUSER_SRCS_DIR}/runtime/fusion_executor_cache.cpp ${NVFUSER_SRCS_DIR}/runtime/fusion_kernel_runtime.cpp diff --git a/csrc/runtime/executor.cpp b/csrc/runtime/executor.cpp index 664c2b02199..446eb271468 100644 --- a/csrc/runtime/executor.cpp +++ b/csrc/runtime/executor.cpp @@ -53,72 +53,6 @@ std::unique_ptr& KernelExecutor:: return evaluator_precomputed_values_; } -bool ExprEvalExecutor::supported(Fusion* fusion) { - 
FUSER_PERF_SCOPE("ExprEvalExecutor::supported"); - return std::all_of( - fusion->outputs().begin(), fusion->outputs().end(), [&fusion](Val* out) { - return fusion->getOutputAlias(out).type == AllocationType::Evaluate; - }); -} - -void ExprEvalExecutor::compile(Fusion* fusion) { - FUSER_PERF_SCOPE("ExprEvalExecutor::compile"); - if (isProfilerEnabled()) { - FusionProfiler::segment(group_id_).startCompile(); - } - NVF_ERROR( - supported(fusion), - "ExprEvalExecutor does not support the Fusion provided."); - fusion_ = std::make_unique(*fusion); - if (isProfilerEnabled()) { - FusionProfiler::segment(group_id_).stopCompile(); - } -} - -bool ExprEvalExecutor::isCompiled() const { - return fusion_ != nullptr; -} - -std::vector ExprEvalExecutor::run( - KernelArgumentHolder& args, - std::vector outputs) { - FUSER_PERF_SCOPE("ExprEvalExecutor::run"); - - if (isProfilerEnabled()) { - NVF_CHECK( - group_id_ >= 0, - "An invalid segment id is passed to FusionProfiler!:", - group_id_); - SegmentProfiler& sprof = FusionProfiler::segment(group_id_); - sprof.inputBytesAccessed(computeBytes(args)); - sprof.scheduler(toString(SchedulerType::ExprEval)); - sprof.startKernel(); - } - - NVF_ERROR(fusion_, "Need to compile before you can run."); - // Bind fusion inputs - auto expr_eval = executor_utils::bindInputs(args, fusion_.get()); - { - NVF_ERROR( - outputs.empty(), - "Fusion executor is using expression evaluator,", - " and expects that the outputs are not populated, which they were."); - if (outputs.empty()) { - for (const auto& out_val : fusion_->outputs()) { - auto out_tensor = - expr_eval.evaluate(out_val->as()).as(); - expr_eval.bind(out_val, out_tensor); - outputs.emplace_back(out_tensor); - } - } - } - if (isProfilerEnabled()) { - FusionProfiler::segment(group_id_).stopKernel(); - FusionProfiler::segment(group_id_).setDevice(args.getDeviceIndex()); - } - return outputs; -} - namespace { bool hasCpuScalarOutputs(Fusion* _fusion) { if (_fusion->exprs().empty()) { diff --git 
a/csrc/runtime/executor.h b/csrc/runtime/executor.h index ba5d1c58c40..d6b365a3ac1 100644 --- a/csrc/runtime/executor.h +++ b/csrc/runtime/executor.h @@ -29,35 +29,6 @@ namespace nvfuser { -class ExprEvalExecutor : public ExecutorAbstract { - public: - ExprEvalExecutor( - int64_t fusion_id = 0, - int64_t concrete_id = 0, - int64_t runtime_id = 0, - int64_t group_id = 0) - : ExecutorAbstract(fusion_id, concrete_id, runtime_id, group_id) {} - - // Returns true if all fusion outputs are expression evaluated. - static bool supported(Fusion* fusion); - - void compile(Fusion* fusion); - - bool isCompiled() const override; - - NVF_API std::vector run( - KernelArgumentHolder& args, - std::vector outputs = {}); - - const std::unique_ptr& fusion() { - return fusion_; - } - - private: - // TODO: Set properly - std::unique_ptr fusion_; -}; - class KernelExecutor : public ExecutorAbstract { public: // NVF_API was added for nvfuser_extension. See examples/sinh_extension. diff --git a/csrc/runtime/executor_dispatch.cpp b/csrc/runtime/executor_dispatch.cpp index 8012a33af0a..e0e74c38d58 100644 --- a/csrc/runtime/executor_dispatch.cpp +++ b/csrc/runtime/executor_dispatch.cpp @@ -10,6 +10,7 @@ #include #include +#include #include diff --git a/csrc/runtime/expr_eval_exec.cpp b/csrc/runtime/expr_eval_exec.cpp new file mode 100644 index 00000000000..333c9458732 --- /dev/null +++ b/csrc/runtime/expr_eval_exec.cpp @@ -0,0 +1,88 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include + +#include +#include + +namespace nvfuser { + +bool ExprEvalExecutor::supported(Fusion* fusion) { + FUSER_PERF_SCOPE("ExprEvalExecutor::supported"); + return std::all_of( + fusion->outputs().begin(), fusion->outputs().end(), [&fusion](Val* out) { + return fusion->getOutputAlias(out).type == AllocationType::Evaluate; + }); +} + +void ExprEvalExecutor::compile(Fusion* fusion) { + FUSER_PERF_SCOPE("ExprEvalExecutor::compile"); + if (isProfilerEnabled()) { + FusionProfiler::segment(group_id_).startCompile(); + } + NVF_ERROR( + supported(fusion), + "ExprEvalExecutor does not support the Fusion provided."); + fusion_ = std::make_unique(*fusion); + if (isProfilerEnabled()) { + FusionProfiler::segment(group_id_).stopCompile(); + } +} + +bool ExprEvalExecutor::isCompiled() const { + return fusion_ != nullptr; +} + +std::vector ExprEvalExecutor::run( + KernelArgumentHolder& args, + std::vector outputs) { + FUSER_PERF_SCOPE("ExprEvalExecutor::run"); + + if (isProfilerEnabled()) { + NVF_CHECK( + group_id_ >= 0, + "An invalid segment id is passed to FusionProfiler!:", + group_id_); + SegmentProfiler& sprof = FusionProfiler::segment(group_id_); + sprof.inputBytesAccessed(computeBytes(args)); + sprof.scheduler(toString(SchedulerType::ExprEval)); + sprof.startKernel(); + } + + NVF_ERROR(fusion_, "Need to compile before you can run."); + // Bind fusion inputs + ExpressionEvaluator expr_eval; + + { + FUSER_PERF_SCOPE("ExprEvalExecutor::bindInputs"); + expr_eval = executor_utils::bindInputs(args, fusion_.get()); + } + { + FUSER_PERF_SCOPE("ExprEvalExecutor::Eval"); + NVF_ERROR( + outputs.empty(), + "Fusion executor is using expression evaluator,", + " and expects that the outputs are not populated, which they were."); + if (outputs.empty()) { + for (const auto& out_val : fusion_->outputs()) { + auto out_tensor = + expr_eval.evaluate(out_val->as()).as(); + expr_eval.bind(out_val, out_tensor); + 
outputs.emplace_back(out_tensor); + } + } + } + if (isProfilerEnabled()) { + FusionProfiler::segment(group_id_).stopKernel(); + FusionProfiler::segment(group_id_).setDevice(args.getDeviceIndex()); + } + return outputs; +} + +} // namespace nvfuser diff --git a/csrc/runtime/expr_eval_exec.h b/csrc/runtime/expr_eval_exec.h new file mode 100644 index 00000000000..67310b19767 --- /dev/null +++ b/csrc/runtime/expr_eval_exec.h @@ -0,0 +1,42 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once +#include +#include + +namespace nvfuser { + +class ExprEvalExecutor : public ExecutorAbstract { + public: + ExprEvalExecutor( + int64_t fusion_id = 0, + int64_t concrete_id = 0, + int64_t runtime_id = 0, + int64_t group_id = 0) + : ExecutorAbstract(fusion_id, concrete_id, runtime_id, group_id) {} + + // Returns true if all fusion outputs are expression evaluated. + static bool supported(Fusion* fusion); + + void compile(Fusion* fusion); + + bool isCompiled() const override; + + NVF_API std::vector run( + KernelArgumentHolder& args, + std::vector outputs = {}); + + const std::unique_ptr& fusion() { + return fusion_; + } + + private: + // TODO: Set properly + std::unique_ptr fusion_; +}; +} // namespace nvfuser diff --git a/tests/cpp/test_alias.cpp b/tests/cpp/test_alias.cpp index 87bb95e474b..d80bcf90593 100644 --- a/tests/cpp/test_alias.cpp +++ b/tests/cpp/test_alias.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include From d9a627ae9bafeca909a749a63b5244887c2a5bca Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Thu, 30 Jan 2025 17:25:55 -0800 Subject: [PATCH 05/16] Add fast execution for permute. Add temporary timing utilities. 
--- csrc/expr_evaluator.cpp | 2 + csrc/ir/evaluate.cpp | 3 + csrc/runtime/expr_eval_exec.cpp | 134 +++++++++++++++++++++++++++++--- csrc/runtime/expr_eval_exec.h | 21 ++++- tests/cpp/test_evaluator.cpp | 17 ++-- 5 files changed, 158 insertions(+), 19 deletions(-) diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp index a2ebccfb7b3..d8f854295bc 100644 --- a/csrc/expr_evaluator.cpp +++ b/csrc/expr_evaluator.cpp @@ -134,6 +134,7 @@ void ExpressionEvaluator::bindTensorDomain( const TensorView* tv, const at::Tensor& t, const bool evaluate_validate) { + FUSER_PERF_SCOPE("ExpressionEvaluator::bindTensorDomain"); auto logical_domain = TensorDomain::noReductions(tv->getLogicalDomain()); NVF_ERROR( t.dim() == (int64_t)logical_domain.size(), @@ -178,6 +179,7 @@ void ExpressionEvaluator::bind_( const Val* value, PolymorphicValue concrete_value, bool evaluate_validate) { + FUSER_PERF_SCOPE("ExpressionEvaluator::bind_"); using namespace PolymorphicValue_functions; NVF_CHECK(concrete_value.hasValue(), "Cannot bind to undefined value"); if (value->isConst()) { diff --git a/csrc/ir/evaluate.cpp b/csrc/ir/evaluate.cpp index df1abba965a..04ad9132285 100644 --- a/csrc/ir/evaluate.cpp +++ b/csrc/ir/evaluate.cpp @@ -685,6 +685,9 @@ std::vector ViewOp::evaluate( // doesn't always produce an alias. For example, when merging an expanded // `IterType::Broadcast` and an `IterType::Iteration`, `ViewOp` has to realize // the expand. 
+ if (in_tensor.is_contiguous()) { + return {in_tensor.view(out_shape)}; + } return {in_tensor.reshape(out_shape)}; } diff --git a/csrc/runtime/expr_eval_exec.cpp b/csrc/runtime/expr_eval_exec.cpp index 333c9458732..e53ebc13878 100644 --- a/csrc/runtime/expr_eval_exec.cpp +++ b/csrc/runtime/expr_eval_exec.cpp @@ -11,6 +11,8 @@ #include #include +#include + namespace nvfuser { bool ExprEvalExecutor::supported(Fusion* fusion) { @@ -30,9 +32,18 @@ void ExprEvalExecutor::compile(Fusion* fusion) { supported(fusion), "ExprEvalExecutor does not support the Fusion provided."); fusion_ = std::make_unique(*fusion); + exprs_ = fusion_->exprs(); + for (auto expr : exprs_) { + if (expr->isA()) { + compile(expr->as()); + } else if (expr->isA()) { + compile(expr->as()); + } + } if (isProfilerEnabled()) { FusionProfiler::segment(group_id_).stopCompile(); } + cudaProfilerStart(); } bool ExprEvalExecutor::isCompiled() const { @@ -44,6 +55,11 @@ std::vector ExprEvalExecutor::run( std::vector outputs) { FUSER_PERF_SCOPE("ExprEvalExecutor::run"); + NVF_ERROR( + outputs.empty(), + "Fusion executor is using expression evaluator,", + " and expects that the outputs are not populated, which they were."); + if (isProfilerEnabled()) { NVF_CHECK( group_id_ >= 0, @@ -65,17 +81,26 @@ std::vector ExprEvalExecutor::run( } { FUSER_PERF_SCOPE("ExprEvalExecutor::Eval"); - NVF_ERROR( - outputs.empty(), - "Fusion executor is using expression evaluator,", - " and expects that the outputs are not populated, which they were."); - if (outputs.empty()) { - for (const auto& out_val : fusion_->outputs()) { - auto out_tensor = - expr_eval.evaluate(out_val->as()).as(); - expr_eval.bind(out_val, out_tensor); - outputs.emplace_back(out_tensor); + + for (auto expr : exprs_) { + if (ViewOp* view = dynamic_cast(expr)) { + auto output_tensor = + run(view, expr_eval.evaluate(view->in()).as()); + expr_eval.bind(view->out(), output_tensor); + continue; + } else if (LoadStoreOp* ld_st_op = dynamic_cast(expr)) { + auto 
output_tensor = + run(ld_st_op, expr_eval.evaluate(ld_st_op->in()).as()); + expr_eval.bind(ld_st_op->out(), output_tensor); + continue; } + expr_eval.evaluate(expr->outputs()[0]); + } + + for (const auto& out_val : fusion_->outputs()) { + auto out_tensor = expr_eval.evaluate(out_val).as(); + // expr_eval.bind(out_val, out_tensor); + outputs.emplace_back(out_tensor); } } if (isProfilerEnabled()) { @@ -85,4 +110,93 @@ std::vector ExprEvalExecutor::run( return outputs; } +namespace { +bool isContiguous(TensorView* tv) { + auto logical = TensorDomain::noReductions(tv->getLogicalDomain()); + auto alloc = TensorDomain::noReductions(tv->getMaybeAllocationDomain()); + if (logical.size() != alloc.size()) { + return false; + } + for (int64_t id_i : c10::irange(logical.size())) { + if (logical[id_i]->isBroadcast() && alloc[id_i]->isBroadcast()) { + if (logical[id_i]->hasExpandedExtent()) { + return false; + } + continue; + } + if (logical[id_i] != alloc[id_i]) { + return false; + } + if (!tv->getContiguity()[id_i]) { + return false; + } + } + return true; +} +} // namespace + +void ExprEvalExecutor::compile(ViewOp* view_op) { + FUSER_PERF_SCOPE("ExprEvalExecutor::compile(ViewOp* view_op"); + std::vector sizes; + bool neg_1_found = false; + for (auto id : view_op->out()->getLogicalDomain()) { + // Ignore sharded dimensions + if (id->isDeviceDim()) { + sizes.push_back(1); + continue; + } + + // Constant reshape specified dimensions + auto id_size = id->getMaybeExpandedExtent(); + if (id_size->isConstInt()) { + sizes.push_back(id_size->evaluate().as()); + continue; + } + + NVF_ERROR( + !neg_1_found, + "Invalid reshape op found, more than one unknown dimensions size specified."); + + // Only one free variable allowed + sizes.push_back(-1); + neg_1_found = true; + } + output_view_sizes[view_op] = sizes; + + use_view[view_op] = isContiguous(view_op->in()); +} + +at::Tensor ExprEvalExecutor::run(ViewOp* view_op, at::Tensor input) { + FUSER_PERF_SCOPE("ExprEvalExecutor::run(ViewOp* 
view_op"); + if (use_view[view_op]) { + return input.view(output_view_sizes[view_op]); + } + return input.reshape(output_view_sizes[view_op]); +} + +void ExprEvalExecutor::compile(LoadStoreOp* ld_st_op) { + FUSER_PERF_SCOPE("ExprEvalExecutor::compile(LoadStoreOp* ld_st_op"); + if (TensorView* out_tv = dynamic_cast(ld_st_op->out())) { + if (out_tv->hasRoot()) { + std::optional> permutation = + ir_utils::computePermutation( + out_tv->getRootDomain(), out_tv->getLogicalDomain()); + NVF_ERROR( + permutation.has_value(), + "The logical domain of a Set.Permute is supposed to be a permutation of the root domain: ", + out_tv->toString()); + permutation_orders[ld_st_op] = *permutation; + } + } +} + +at::Tensor ExprEvalExecutor::run(LoadStoreOp* ld_st_op, at::Tensor input) { + FUSER_PERF_SCOPE("ExprEvalExecutor::run(LoadStoreOp* ld_st_op"); + auto permute_it = permutation_orders.find(ld_st_op); + if (permute_it == permutation_orders.end()) { + return input; + } + return input.permute(permute_it->second); +} + } // namespace nvfuser diff --git a/csrc/runtime/expr_eval_exec.h b/csrc/runtime/expr_eval_exec.h index 67310b19767..961bf5e2041 100644 --- a/csrc/runtime/expr_eval_exec.h +++ b/csrc/runtime/expr_eval_exec.h @@ -6,6 +6,7 @@ */ // clang-format on #pragma once + #include #include @@ -36,7 +37,25 @@ class ExprEvalExecutor : public ExecutorAbstract { } private: - // TODO: Set properly std::unique_ptr fusion_; + + // Expressions to evaluate + std::vector exprs_; + + // Sizes of the output of view ops, only one value can be unknown at it gets + // processed in aten as a -1 size, every other dim is a constant positive + // integer value. + std::unordered_map> output_view_sizes; + // Indicates if it's safe to use at::view instead of at::reshape + std::unordered_map use_view; + + // Permute map, stores permutation axes if a LoadStoreOp requires them. 
+ std::unordered_map> permutation_orders; + + void compile(ViewOp* view_op); + at::Tensor run(ViewOp* view_op, at::Tensor input); + + void compile(LoadStoreOp* ld_st_op); + at::Tensor run(LoadStoreOp* ld_st_op, at::Tensor input); }; } // namespace nvfuser diff --git a/tests/cpp/test_evaluator.cpp b/tests/cpp/test_evaluator.cpp index 05276720cd1..94ba1e540f4 100644 --- a/tests/cpp/test_evaluator.cpp +++ b/tests/cpp/test_evaluator.cpp @@ -15,6 +15,7 @@ #include #include #include +#include namespace nvfuser { @@ -594,15 +595,15 @@ TEST_F(ExprEvalTest, ReshapePermuteReshape) { out = reshape(out, {IrBuilder::create(6), size(out, 2)}); fusion.addOutput(out); + fusion.aliasOutputToInput(out, in, AllocationType::Evaluate); at::Tensor in_tensor = at::rand({72}).cuda().as_strided({9, 6}, {8, 1}); - - ExpressionEvaluator evaluator; - evaluator.bind(in, in_tensor); - at::Tensor out_tensor = evaluator.evaluate(out).as(); - - EXPECT_EQ(in_tensor.data_ptr(), out_tensor.data_ptr()); - EXPECT_THAT(out_tensor.sizes(), ElementsAre(6, 9)); - EXPECT_THAT(out_tensor.strides(), ElementsAre(1, 8)); + ExprEvalExecutor eee; + eee.compile(&fusion); + auto args = KernelArgumentHolder::createKernelArgumentHolder({in_tensor}); + auto outs = eee.run(args); + EXPECT_EQ(in_tensor.data_ptr(), outs[0].data_ptr()); + EXPECT_THAT(outs[0].sizes(), ElementsAre(6, 9)); + EXPECT_THAT(outs[0].strides(), ElementsAre(1, 8)); } TEST_F(ExprEvalTest, Reshape_ForwardBroadcast) { From cf4d3f81d27c18e7209feed7f22cebaeaa22655a Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sat, 1 Feb 2025 01:06:30 +0000 Subject: [PATCH 06/16] Debugging input binding --- csrc/expr_evaluator.h | 7 +++++ csrc/runtime/executor_utils.cpp | 8 ++--- csrc/runtime/expr_eval_exec.cpp | 53 +++++++++++++++++---------------- tests/cpp/test_evaluator.cpp | 4 +++ 4 files changed, 43 insertions(+), 29 deletions(-) diff --git a/csrc/expr_evaluator.h b/csrc/expr_evaluator.h index b6c8e1857ea..0a79defaaf8 100644 --- a/csrc/expr_evaluator.h 
+++ b/csrc/expr_evaluator.h @@ -22,6 +22,7 @@ namespace nvfuser { class PrecomputedValues; +class ExprEvalExecutor; //! Calculate Fusion IR expressions class ExpressionEvaluator { @@ -91,6 +92,12 @@ class ExpressionEvaluator { ExpressionEvaluator clone(IrCloner& ir_cloner) const; + protected: + friend ExprEvalExecutor; + // Direct access to adding values to known_values_ without going through bind_ + // which does validation and will also bind all tensor domain information. + void unsafeBind(const Val* value, PolymorphicValue concrete_value); + private: void bind_( const Val* value, diff --git a/csrc/runtime/executor_utils.cpp b/csrc/runtime/executor_utils.cpp index 4070b6e2683..f658104920f 100644 --- a/csrc/runtime/executor_utils.cpp +++ b/csrc/runtime/executor_utils.cpp @@ -572,17 +572,17 @@ void validateVectorizedTensors( ExpressionEvaluator bindInputs( const KernelArgumentHolder& args, - Fusion* kernel) { + Fusion* fusion) { FUSER_PERF_SCOPE("executor_utils::bindInputs"); // args may contains more than just inputs, but inputs are always at the // beginning. NVF_ERROR( - kernel->inputs().size() <= args.size(), - "KernelArgumentHolder contains less argument than kernel's input."); + fusion->inputs().size() <= args.size(), + "KernelArgumentHolder contains less argument than fusion's input."); ExpressionEvaluator expr_eval; - const auto& inputs = kernel->inputs(); + const auto& inputs = fusion->inputs(); for (const auto i : c10::irange(inputs.size())) { // NOTE: we bind all inputs here, including at::Tensors. 
This means that // expr_eval will create a PolymorphicValue containing *args[i], which means diff --git a/csrc/runtime/expr_eval_exec.cpp b/csrc/runtime/expr_eval_exec.cpp index e53ebc13878..094f0d683b1 100644 --- a/csrc/runtime/expr_eval_exec.cpp +++ b/csrc/runtime/expr_eval_exec.cpp @@ -33,13 +33,13 @@ void ExprEvalExecutor::compile(Fusion* fusion) { "ExprEvalExecutor does not support the Fusion provided."); fusion_ = std::make_unique(*fusion); exprs_ = fusion_->exprs(); - for (auto expr : exprs_) { - if (expr->isA()) { - compile(expr->as()); - } else if (expr->isA()) { - compile(expr->as()); - } - } + // for (auto expr : exprs_) { + // if (expr->isA()) { + // compile(expr->as()); + // } else if (expr->isA()) { + // compile(expr->as()); + // } + // } if (isProfilerEnabled()) { FusionProfiler::segment(group_id_).stopCompile(); } @@ -74,32 +74,35 @@ std::vector ExprEvalExecutor::run( NVF_ERROR(fusion_, "Need to compile before you can run."); // Bind fusion inputs ExpressionEvaluator expr_eval; - { FUSER_PERF_SCOPE("ExprEvalExecutor::bindInputs"); - expr_eval = executor_utils::bindInputs(args, fusion_.get()); + // expr_eval = executor_utils::bindInputs(args, fusion_.get()); + NVF_ERROR( + fusion_->inputs().size() <= args.size(), + "KernelArgumentHolder contains less argument than fusion's input."); + for(auto inp_i : c10::irange(fusion_->inputs().size())){ + expr_eval.unsafeBind(fusion_->inputs()[inp_i], *args[inp_i]); + } } { FUSER_PERF_SCOPE("ExprEvalExecutor::Eval"); - - for (auto expr : exprs_) { - if (ViewOp* view = dynamic_cast(expr)) { - auto output_tensor = - run(view, expr_eval.evaluate(view->in()).as()); - expr_eval.bind(view->out(), output_tensor); - continue; - } else if (LoadStoreOp* ld_st_op = dynamic_cast(expr)) { - auto output_tensor = - run(ld_st_op, expr_eval.evaluate(ld_st_op->in()).as()); - expr_eval.bind(ld_st_op->out(), output_tensor); - continue; - } - expr_eval.evaluate(expr->outputs()[0]); - } + // for (auto expr : exprs_) { + // if (ViewOp* 
view = dynamic_cast(expr)) { + // auto output_tensor = + // run(view, expr_eval.evaluate(view->in()).as()); + // expr_eval.bind(view->out(), output_tensor); + // continue; + // } else if (LoadStoreOp* ld_st_op = dynamic_cast(expr)) { + // auto output_tensor = + // run(ld_st_op, expr_eval.evaluate(ld_st_op->in()).as()); + // expr_eval.bind(ld_st_op->out(), output_tensor); + // continue; + // } + // expr_eval.evaluate(expr->outputs()[0]); + // } for (const auto& out_val : fusion_->outputs()) { auto out_tensor = expr_eval.evaluate(out_val).as(); - // expr_eval.bind(out_val, out_tensor); outputs.emplace_back(out_tensor); } } diff --git a/tests/cpp/test_evaluator.cpp b/tests/cpp/test_evaluator.cpp index 94ba1e540f4..c4b4e546a3f 100644 --- a/tests/cpp/test_evaluator.cpp +++ b/tests/cpp/test_evaluator.cpp @@ -601,6 +601,10 @@ TEST_F(ExprEvalTest, ReshapePermuteReshape) { eee.compile(&fusion); auto args = KernelArgumentHolder::createKernelArgumentHolder({in_tensor}); auto outs = eee.run(args); + for (auto i : c10::irange(99)) { + (void)i; + eee.run(args); + } EXPECT_EQ(in_tensor.data_ptr(), outs[0].data_ptr()); EXPECT_THAT(outs[0].sizes(), ElementsAre(6, 9)); EXPECT_THAT(outs[0].strides(), ElementsAre(1, 8)); From 398e7a5e07b351cb1bb9d03d1c5ea64511a768b7 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 2 Feb 2025 07:04:26 -0800 Subject: [PATCH 07/16] Remove recursive binding of tensors. 
--- csrc/expr_evaluator.cpp | 16 ++++++++++- csrc/ir/evaluate.cpp | 18 +++++++++++-- csrc/runtime/expr_eval_exec.cpp | 48 ++++++++++++++++----------------- 3 files changed, 55 insertions(+), 27 deletions(-) diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp index d8f854295bc..4e580696f79 100644 --- a/csrc/expr_evaluator.cpp +++ b/csrc/expr_evaluator.cpp @@ -175,6 +175,12 @@ void ExpressionEvaluator::bindTensorDomain( } } +void ExpressionEvaluator::unsafeBind( + const Val* value, + PolymorphicValue concrete_value) { + known_values_[value] = concrete_value; +} + void ExpressionEvaluator::bind_( const Val* value, PolymorphicValue concrete_value, @@ -259,6 +265,10 @@ PolymorphicValue ExpressionEvaluator::evaluate(const Val* value) const { const PolymorphicValue& ExpressionEvaluator::evaluate( const Val* value, std::unordered_map& known_values) const { + // FUSER_PERF_SCOPE("ExpressionEvaluator::evaluate"); + // It's tempting to time this function, the issue is it's a recursive function + // so timings produced by it can be accumulatively longer than the actual time + // spent if (precomputed_values_ && precomputed_values_->hasValidValues()) { if (precomputed_values_->getMaybeValueFor(value).hasValue()) { return precomputed_values_->getMaybeValueFor(value); @@ -269,7 +279,6 @@ const PolymorphicValue& ExpressionEvaluator::evaluate( getValue(value, known_values); if (!maybe_concrete_value.get().hasValue()) { if (auto def = value->definition()) { - FUSER_PERF_SCOPE("ExpressionEvaluator::evaluate"); auto outputs = def->evaluate(*this, known_values); for (auto i : c10::irange(def->outputs().size())) { known_values[def->output(i)] = std::move(outputs[i]); @@ -277,6 +286,11 @@ const PolymorphicValue& ExpressionEvaluator::evaluate( maybe_concrete_value = getValue(value, known_values); } } + NVF_ERROR( + maybe_concrete_value.get().hasValue(), + "Error evaluating a value in expression evaluator. 
Likely ", + value->toString(), + " needs to be bound to a value."); return maybe_concrete_value; } diff --git a/csrc/ir/evaluate.cpp b/csrc/ir/evaluate.cpp index 04ad9132285..2ef2b4a9957 100644 --- a/csrc/ir/evaluate.cpp +++ b/csrc/ir/evaluate.cpp @@ -670,12 +670,26 @@ std::vector ViewOp::evaluate( const std::vector& out_logical = out()->getLogicalDomain(); std::vector out_shape; out_shape.reserve(out_logical.size()); + + int missing_vals = + std::count_if(out_logical.begin(), out_logical.end(), [](IterDomain* id) { + return !id->isDeviceDim() && + !id->getMaybeExpandedExtent()->isConstScalar(); + }); + for (IterDomain* id : out_logical) { if (id->isDeviceDim()) { out_shape.push_back(1); - } else { + } else if (id->getMaybeExpandedExtent()->isConstScalar()) { out_shape.push_back( - ee.evaluate(id->getMaybeExpandedExtent()).as()); + id->getMaybeExpandedExtent()->evaluate().as()); + } else { + if (missing_vals == 1) { + out_shape.push_back(-1); + } else { + out_shape.push_back( + ee.evaluate(id->getMaybeExpandedExtent()).as()); + } } } diff --git a/csrc/runtime/expr_eval_exec.cpp b/csrc/runtime/expr_eval_exec.cpp index 094f0d683b1..e0dd686151d 100644 --- a/csrc/runtime/expr_eval_exec.cpp +++ b/csrc/runtime/expr_eval_exec.cpp @@ -33,13 +33,13 @@ void ExprEvalExecutor::compile(Fusion* fusion) { "ExprEvalExecutor does not support the Fusion provided."); fusion_ = std::make_unique(*fusion); exprs_ = fusion_->exprs(); - // for (auto expr : exprs_) { - // if (expr->isA()) { - // compile(expr->as()); - // } else if (expr->isA()) { - // compile(expr->as()); - // } - // } + for (auto expr : exprs_) { + if (expr->isA()) { + compile(expr->as()); + } else if (expr->isA()) { + compile(expr->as()); + } + } if (isProfilerEnabled()) { FusionProfiler::segment(group_id_).stopCompile(); } @@ -78,28 +78,28 @@ std::vector ExprEvalExecutor::run( FUSER_PERF_SCOPE("ExprEvalExecutor::bindInputs"); // expr_eval = executor_utils::bindInputs(args, fusion_.get()); NVF_ERROR( - 
fusion_->inputs().size() <= args.size(), - "KernelArgumentHolder contains less argument than fusion's input."); - for(auto inp_i : c10::irange(fusion_->inputs().size())){ + fusion_->inputs().size() <= args.size(), + "KernelArgumentHolder contains less argument than fusion's input."); + for (auto inp_i : c10::irange(fusion_->inputs().size())) { expr_eval.unsafeBind(fusion_->inputs()[inp_i], *args[inp_i]); } } { FUSER_PERF_SCOPE("ExprEvalExecutor::Eval"); - // for (auto expr : exprs_) { - // if (ViewOp* view = dynamic_cast(expr)) { - // auto output_tensor = - // run(view, expr_eval.evaluate(view->in()).as()); - // expr_eval.bind(view->out(), output_tensor); - // continue; - // } else if (LoadStoreOp* ld_st_op = dynamic_cast(expr)) { - // auto output_tensor = - // run(ld_st_op, expr_eval.evaluate(ld_st_op->in()).as()); - // expr_eval.bind(ld_st_op->out(), output_tensor); - // continue; - // } - // expr_eval.evaluate(expr->outputs()[0]); - // } + for (auto expr : exprs_) { + if (ViewOp* view = dynamic_cast(expr)) { + auto output_tensor = + run(view, expr_eval.evaluate(view->in()).as()); + expr_eval.unsafeBind(view->out(), output_tensor); + continue; + } else if (LoadStoreOp* ld_st_op = dynamic_cast(expr)) { + auto output_tensor = + run(ld_st_op, expr_eval.evaluate(ld_st_op->in()).as()); + expr_eval.unsafeBind(ld_st_op->out(), output_tensor); + continue; + } + expr_eval.evaluate(expr->outputs()[0]); + } for (const auto& out_val : fusion_->outputs()) { auto out_tensor = expr_eval.evaluate(out_val).as(); From e2e6f184f63fb1e4af33fea300655cf563ee1d53 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 2 Feb 2025 10:38:53 -0800 Subject: [PATCH 08/16] Support dynamic reshape ops in expr eval exec. 
--- csrc/expr_evaluator.cpp | 11 +++--- csrc/runtime/expr_eval_exec.cpp | 60 ++++++++++++++++++++++----------- csrc/runtime/expr_eval_exec.h | 26 ++++++++++---- 3 files changed, 65 insertions(+), 32 deletions(-) diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp index 4e580696f79..cf5646df21e 100644 --- a/csrc/expr_evaluator.cpp +++ b/csrc/expr_evaluator.cpp @@ -286,11 +286,12 @@ const PolymorphicValue& ExpressionEvaluator::evaluate( maybe_concrete_value = getValue(value, known_values); } } - NVF_ERROR( - maybe_concrete_value.get().hasValue(), - "Error evaluating a value in expression evaluator. Likely ", - value->toString(), - " needs to be bound to a value."); + // TODO: Evaluate if an error like below could work + // NVF_ERROR( + // maybe_concrete_value.get().hasValue(), + // "Error evaluating a value in expression evaluator. Likely ", + // value->toString(), + // " needs to be bound to a value."); return maybe_concrete_value; } diff --git a/csrc/runtime/expr_eval_exec.cpp b/csrc/runtime/expr_eval_exec.cpp index e0dd686151d..01414801918 100644 --- a/csrc/runtime/expr_eval_exec.cpp +++ b/csrc/runtime/expr_eval_exec.cpp @@ -88,8 +88,7 @@ std::vector ExprEvalExecutor::run( FUSER_PERF_SCOPE("ExprEvalExecutor::Eval"); for (auto expr : exprs_) { if (ViewOp* view = dynamic_cast(expr)) { - auto output_tensor = - run(view, expr_eval.evaluate(view->in()).as()); + auto output_tensor = run(view, expr_eval); expr_eval.unsafeBind(view->out(), output_tensor); continue; } else if (LoadStoreOp* ld_st_op = dynamic_cast(expr)) { @@ -140,41 +139,62 @@ bool isContiguous(TensorView* tv) { void ExprEvalExecutor::compile(ViewOp* view_op) { FUSER_PERF_SCOPE("ExprEvalExecutor::compile(ViewOp* view_op"); - std::vector sizes; - bool neg_1_found = false; + std::vector sizes; + for (auto id : view_op->out()->getLogicalDomain()) { // Ignore sharded dimensions if (id->isDeviceDim()) { - sizes.push_back(1); + sizes.push_back(FusionGuard::getCurFusion()->oneVal()); continue; } // 
Constant reshape specified dimensions auto id_size = id->getMaybeExpandedExtent(); - if (id_size->isConstInt()) { - sizes.push_back(id_size->evaluate().as()); + if (id_size->isConstInt() && id_size->definition() != nullptr) { + sizes.push_back( + IrBuilder::create(id_size->evaluate().as())); continue; } - NVF_ERROR( - !neg_1_found, - "Invalid reshape op found, more than one unknown dimensions size specified."); - - // Only one free variable allowed - sizes.push_back(-1); - neg_1_found = true; + sizes.push_back(id_size); } - output_view_sizes[view_op] = sizes; - use_view[view_op] = isContiguous(view_op->in()); + int missing_vals = std::count_if(sizes.begin(), sizes.end(), [](Val* size) { + return !size->isConstScalar(); + }); + + ViewInfo view_info = {sizes, missing_vals <= 1, isContiguous(view_op->in())}; + + view_infos[view_op] = view_info; } -at::Tensor ExprEvalExecutor::run(ViewOp* view_op, at::Tensor input) { +at::Tensor ExprEvalExecutor::run( + ViewOp* view_op, + ExpressionEvaluator& expr_eval) { FUSER_PERF_SCOPE("ExprEvalExecutor::run(ViewOp* view_op"); - if (use_view[view_op]) { - return input.view(output_view_sizes[view_op]); + auto view_info_it = view_infos.find(view_op); + NVF_ERROR( + view_info_it != view_infos.end(), + "Error running ViewOp, it wasn't compiled."); + ViewInfo& view_info = view_info_it->second; + + std::vector sizes; + for (auto size : view_info.output_view_sizes) { + if (size->isConstInt()) { + sizes.push_back(size->value().as()); + } else if (view_info.use_neg_1) { + sizes.push_back(-1); + } else { + expr_eval.evaluate(size).as(); + } + } + + auto input = expr_eval.evaluate(view_op->in()).as(); + + if (view_info.use_at_view) { + return input.view(sizes); } - return input.reshape(output_view_sizes[view_op]); + return input.reshape(sizes); } void ExprEvalExecutor::compile(LoadStoreOp* ld_st_op) { diff --git a/csrc/runtime/expr_eval_exec.h b/csrc/runtime/expr_eval_exec.h index 961bf5e2041..dc7d0d9c1b5 100644 --- 
a/csrc/runtime/expr_eval_exec.h +++ b/csrc/runtime/expr_eval_exec.h @@ -42,18 +42,30 @@ class ExprEvalExecutor : public ExecutorAbstract { // Expressions to evaluate std::vector exprs_; - // Sizes of the output of view ops, only one value can be unknown at it gets - // processed in aten as a -1 size, every other dim is a constant positive - // integer value. - std::unordered_map> output_view_sizes; - // Indicates if it's safe to use at::view instead of at::reshape - std::unordered_map use_view; + struct ViewInfo { + // Sizes of the output of view ops, only one value can be unknown at it gets + // processed in aten as a -1 size, every other dim is a constant positive + // integer value. + std::vector output_view_sizes; + // PyTorch's API defines all output shapes as a constant known size except + // upto 1 which can be easily inferred based on the input numel and the rest + // of the ouput sizes. nvFuser can have dynamic reshape operations where the + // output sizes are inferred through split and merge operations on IDs. If + // use_neg_1 is true then all values except up to one are constant values. + bool use_neg_1 = false; + // at::view can be used on contiguous tensors and is faster than + // at::reshape. Since we know at compile time if the tensor is contiguous + // then we can route evaluation to view. + bool use_at_view = false; + }; + + std::unordered_map view_infos; // Permute map, stores permutation axes if a LoadStoreOp requires them. std::unordered_map> permutation_orders; void compile(ViewOp* view_op); - at::Tensor run(ViewOp* view_op, at::Tensor input); + at::Tensor run(ViewOp* view_op, ExpressionEvaluator& expr_eval); void compile(LoadStoreOp* ld_st_op); at::Tensor run(LoadStoreOp* ld_st_op, at::Tensor input); From cb41c521c49fa979f5dcb871c07a5dec28314b28 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 2 Feb 2025 11:01:46 -0800 Subject: [PATCH 09/16] Simplify tensor extents in eee. 
--- csrc/device_lower/pass/replace_size.cpp | 3 --- csrc/device_lower/pass/replace_size.h | 19 +++++++++++++++++++ csrc/runtime/expr_eval_exec.cpp | 7 +++++++ 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/csrc/device_lower/pass/replace_size.cpp b/csrc/device_lower/pass/replace_size.cpp index 2cfe1405050..d4c2ea474aa 100644 --- a/csrc/device_lower/pass/replace_size.cpp +++ b/csrc/device_lower/pass/replace_size.cpp @@ -18,7 +18,6 @@ namespace nvfuser { -namespace { // Going to generate a map of tensor view root domain extents to reduce the // number used during lowering. For example if we have: // @@ -137,8 +136,6 @@ std::unordered_map getSimplificationMap(Fusion* fusion) { return simplification_map; } -} // namespace - void replaceSymbolicSizes(Fusion* fusion) { FUSER_PERF_SCOPE("GpuLower::Lower::replaceSymbolicSizes"); std::unordered_map tensor_dim_map; diff --git a/csrc/device_lower/pass/replace_size.h b/csrc/device_lower/pass/replace_size.h index ca874ab836d..aab690f1df7 100644 --- a/csrc/device_lower/pass/replace_size.h +++ b/csrc/device_lower/pass/replace_size.h @@ -21,4 +21,23 @@ namespace nvfuser { // tensors to reference the runtime structure containing sizes. void replaceSymbolicSizes(Fusion*); +// Going to generate a map of tensor view root domain extents to reduce the +// number used during lowering. For example if we have: +// +// T2[i0, i1] = T1[i0, i1] + T2[i2, i3] +// +// We know it would be safe to use: +// +// T2[i0, i1] = T1[i0, i1] + T2[i0, i1] +// +// And that way we don't generate T2.size[0] and T2.size[1], instead we will +// reuse T1.size[0] and T1.size[1] +// This is important when doing CSE as T2 and T1 would otherwise look like +// they're using different values, even though we know they're the same +// +// There's some duplicate logic here that's in computeAt map, but it's not so +// concice there to pull out. May want to consider making this mapping its own +// class especially as it may be useful during scheduling. 
+std::unordered_map getSimplificationMap(Fusion* fusion); + } // namespace nvfuser diff --git a/csrc/runtime/expr_eval_exec.cpp b/csrc/runtime/expr_eval_exec.cpp index 01414801918..bbc91e235c6 100644 --- a/csrc/runtime/expr_eval_exec.cpp +++ b/csrc/runtime/expr_eval_exec.cpp @@ -8,8 +8,10 @@ #include +#include #include #include +#include #include @@ -32,6 +34,10 @@ void ExprEvalExecutor::compile(Fusion* fusion) { supported(fusion), "ExprEvalExecutor does not support the Fusion provided."); fusion_ = std::make_unique(*fusion); + + auto extent_simplification_map = getSimplificationMap(fusion_.get()); + auto mutation_map = ir_utils::replaceValue(fusion_.get(), extent_simplification_map); + exprs_ = fusion_->exprs(); for (auto expr : exprs_) { if (expr->isA()) { @@ -39,6 +45,7 @@ void ExprEvalExecutor::compile(Fusion* fusion) { } else if (expr->isA()) { compile(expr->as()); } + //TODO: support RepeatOp and other ops that require ee.evaluate in evaluate.cpp } if (isProfilerEnabled()) { FusionProfiler::segment(group_id_).stopCompile(); From 012099225c4cb400670217a355023d94dc60e76a Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 2 Feb 2025 17:30:25 -0800 Subject: [PATCH 10/16] Finish implementing dynamic reshape operations and only bind what's necessary for inference. Still need to fix RehsapeToSlice test. 
--- csrc/runtime/expr_eval_exec.cpp | 139 +++++++++++++++++++++++++++++++- csrc/runtime/expr_eval_exec.h | 20 +++++ 2 files changed, 155 insertions(+), 4 deletions(-) diff --git a/csrc/runtime/expr_eval_exec.cpp b/csrc/runtime/expr_eval_exec.cpp index bbc91e235c6..d32f1b82f5b 100644 --- a/csrc/runtime/expr_eval_exec.cpp +++ b/csrc/runtime/expr_eval_exec.cpp @@ -25,6 +25,63 @@ bool ExprEvalExecutor::supported(Fusion* fusion) { }); } +void ExprEvalExecutor::findAndBindInputTVExtentFrom(Val* val) { + if (val->isFusionInput()) { + return; + } + if (val->isConstInt()) { + return; + } + + auto tv_info_it = extent_to_tv_info.find(val); + if (tv_info_it != extent_to_tv_info.end()) { + tv_sizes_to_bind.push_back(tv_info_it->second); + return; + } + + auto inputs = InputsOf::output(val); + for (auto inp : inputs) { + if (inp->isConstInt()) { + continue; + } + tv_info_it = extent_to_tv_info.find(inp); + NVF_ERROR( + tv_info_it != extent_to_tv_info.end(), + "Error deducing how to infer ", + val->toInlineString()); + tv_sizes_to_bind.push_back(tv_info_it->second); + } +} + +void ExprEvalExecutor::deduplicateTvSizesToBind() { + // Sort by tv pointer, fusion_input_pos (ascending), then logical_dim_pos + // (ascending) + std::sort( + tv_sizes_to_bind.begin(), + tv_sizes_to_bind.end(), + [](const TVInfo& a, const TVInfo& b) { + if (a.tv != b.tv) { + return a.tv < b.tv; + } + + if (a.fusion_input_pos != b.fusion_input_pos) { + return a.fusion_input_pos < b.fusion_input_pos; + } + + return a.logical_dim_pos < b.logical_dim_pos; + }); + + // remove consecutive duplicates + auto last = std::unique( + tv_sizes_to_bind.begin(), + tv_sizes_to_bind.end(), + [](const TVInfo& a, const TVInfo& b) { + return a.tv == b.tv && a.fusion_input_pos == b.fusion_input_pos && + a.logical_dim_pos == b.logical_dim_pos; + }); + tv_sizes_to_bind.erase(last, tv_sizes_to_bind.end()); +} + void ExprEvalExecutor::compile(Fusion* fusion) { FUSER_PERF_SCOPE("ExprEvalExecutor::compile"); if (isProfilerEnabled()) 
{ @@ -34,9 +91,20 @@ void ExprEvalExecutor::compile(Fusion* fusion) { supported(fusion), "ExprEvalExecutor does not support the Fusion provided."); fusion_ = std::make_unique(*fusion); - auto extent_simplification_map = getSimplificationMap(fusion_.get()); - auto mutation_map = ir_utils::replaceValue(fusion_.get(), extent_simplification_map); + auto mutation_map = + ir_utils::replaceValue(fusion_.get(), extent_simplification_map); + + // Build extent to input tv info map + for (auto inp_id : c10::irange(fusion_->inputs().size())) { + if (TensorView* tv = dynamic_cast(fusion_->inputs()[inp_id])) { + auto domain = TensorDomain::noReductions(tv->getLogicalDomain()); + for (auto id_i : c10::irange(domain.size())) { + auto extent = domain[id_i]->getMaybeExpandedExtent(); + extent_to_tv_info[extent] = {tv, inp_id, id_i}; + } + } + } exprs_ = fusion_->exprs(); for (auto expr : exprs_) { @@ -45,8 +113,17 @@ void ExprEvalExecutor::compile(Fusion* fusion) { } else if (expr->isA()) { compile(expr->as()); } - //TODO: support RepeatOp and other ops that require ee.evaluate in evaluate.cpp + // TODO: support RepeatOp and GetMetaData + + for (auto expr_inp : expr->inputs()) { + if (expr_inp->isIntegralScalar()) { + findAndBindInputTVExtentFrom(expr_inp); + } + } } + + deduplicateTvSizesToBind(); + if (isProfilerEnabled()) { FusionProfiler::segment(group_id_).stopCompile(); } @@ -90,6 +167,52 @@ std::vector ExprEvalExecutor::run( for (auto inp_i : c10::irange(fusion_->inputs().size())) { expr_eval.unsafeBind(fusion_->inputs()[inp_i], *args[inp_i]); } + + for (auto tv_info : tv_sizes_to_bind) { + NVF_ERROR( + tv_info.fusion_input_pos < fusion_->inputs().size(), + "Error processing tv_info, asked for fusion input ", + tv_info.fusion_input_pos, + " but fusion only has ", + fusion_->inputs().size(), + " inputs"); + + Val* fusion_input = fusion_->inputs()[tv_info.fusion_input_pos]; + + NVF_ERROR( + fusion_input->isA(), + "Expected provided input to be a tensor view but found ", + 
fusion_input->toString()); + + auto tv = fusion_input->as(); + + NVF_ERROR( + tv == tv_info.tv, + "Expected fusion input[", + tv_info.fusion_input_pos, + "] to be ", + tv_info.tv->toString(), + " but found ", + tv->toString()); + + auto logical_domain = TensorDomain::noReductions(tv->getLogicalDomain()); + + NVF_ERROR( + tv_info.logical_dim_pos < logical_domain.size(), + "Expected tensor view, ", + tv->toString(), + ", to have a logical domain of size at least ", + tv_info.logical_dim_pos, + " but only found ", + logical_domain.size(), + " dimensions."); + + expr_eval.unsafeBind( + logical_domain[tv_info.logical_dim_pos]->getMaybeExpandedExtent(), + (*args[tv_info.fusion_input_pos]) + .as() + .sizes()[tv_info.logical_dim_pos]); + } } { FUSER_PERF_SCOPE("ExprEvalExecutor::Eval"); @@ -104,7 +227,7 @@ std::vector ExprEvalExecutor::run( expr_eval.unsafeBind(ld_st_op->out(), output_tensor); continue; } - expr_eval.evaluate(expr->outputs()[0]); + auto infer_val = expr_eval.evaluate(expr->outputs()[0]); } for (const auto& out_val : fusion_->outputs()) { @@ -170,6 +293,14 @@ void ExprEvalExecutor::compile(ViewOp* view_op) { return !size->isConstScalar(); }); + // Record which vals need to be inferred and what input bindings we need to + // infer them. + if (missing_vals > 1) { + for (auto size : sizes) { + findAndBindInputTVExtentFrom(size); + } + } + ViewInfo view_info = {sizes, missing_vals <= 1, isContiguous(view_op->in())}; view_infos[view_op] = view_info; diff --git a/csrc/runtime/expr_eval_exec.h b/csrc/runtime/expr_eval_exec.h index dc7d0d9c1b5..9e19a6dc132 100644 --- a/csrc/runtime/expr_eval_exec.h +++ b/csrc/runtime/expr_eval_exec.h @@ -64,6 +64,26 @@ class ExprEvalExecutor : public ExecutorAbstract { // Permute map, stores permutation axes if a LoadStoreOp requires them. std::unordered_map> permutation_orders; + struct TVInfo { + TensorView* tv; + uint64_t fusion_input_pos; + uint64_t logical_dim_pos; + }; + + // Expr eval exec only shallowly binds inputs. 
This means all sizes of each + // tensor are not bound. During compilation information about which size + // information needs to be pulled and bound are tracked. References entries in + // extent_to_tv_info map. + std::vector tv_sizes_to_bind; + std::unordered_map extent_to_tv_info; + + // Goes to val's inputs and check if it's from a TensorView, if so it fills + // tv_sizes_to_bind for those inputs. + void findAndBindInputTVExtentFrom(Val* val); + + // deduplicate entries in tv_sizes_to_bind + void deduplicateTvSizesToBind(); + void compile(ViewOp* view_op); at::Tensor run(ViewOp* view_op, ExpressionEvaluator& expr_eval); From d9df08c332a78bdb4b49989873562f2732efa57e Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Mon, 3 Feb 2025 07:54:32 -0800 Subject: [PATCH 11/16] Fix ReshapeToSlice test, enable binding input tensor sizes when the scalar's aren't an input value based on the InputsOf function. --- csrc/runtime/expr_eval_exec.cpp | 95 +++++++++++++++------------------ csrc/runtime/expr_eval_exec.h | 36 +++++++++++-- 2 files changed, 74 insertions(+), 57 deletions(-) diff --git a/csrc/runtime/expr_eval_exec.cpp b/csrc/runtime/expr_eval_exec.cpp index d32f1b82f5b..6be16010784 100644 --- a/csrc/runtime/expr_eval_exec.cpp +++ b/csrc/runtime/expr_eval_exec.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include @@ -25,63 +26,45 @@ bool ExprEvalExecutor::supported(Fusion* fusion) { }); } -void ExprEvalExecutor::findAndBindInputTVExtentFrom(Val* val) { - if (val->isFusionInput()) { - return; - } - if (val->isConstInt()) { - return; +void ExprEvalExecutor::findAndBindInputTVExtentsFrom( + VectorOfUniqueEntries vals) { + for (auto val : vals) { + if (val->isFusionInput()) { + // Could be an input scalar value that will be bound when we bind inputs. 
+ vals.erase(val); + continue; + } + if (val->isConstInt()) { + // Const scalars don't need to be bound + vals.erase(val); + continue; + } + auto tv_info_it = extent_to_tv_info.find(val); + if (tv_info_it != extent_to_tv_info.end()) { + // val is a TV logical ID, use that + tv_sizes_to_bind.pushBack(tv_info_it->second); + vals.erase(val); + } } - auto tv_info_it = extent_to_tv_info.find(val); - if (tv_info_it != extent_to_tv_info.end()) { - tv_sizes_to_bind.push_back(tv_info_it->second); - return; - } + auto deps = DependencyCheck::getAllValsBetween( + all_potential_input_scalars.set(), vals.vector()); + VectorOfUniqueEntries unique_deps(deps); + auto inputs = all_potential_input_scalars.computeIntersect(unique_deps); - auto inputs = InputsOf::output(val); for (auto inp : inputs) { if (inp->isConstInt()) { + // Const scalars don't need to be bound continue; } - tv_info_it = extent_to_tv_info.find(inp); - NVF_ERROR( - tv_info_it != extent_to_tv_info.end(), - "Error deducing how to infer ", - val->toInlineString()); - tv_sizes_to_bind.push_back(tv_info_it->second); + if (inp->isFusionInput()) { + // Could be an input scalar value that will be bound when we bind inputs. 
+ continue; + } + tv_sizes_to_bind.pushBack(extent_to_tv_info[inp]); } } -void ExprEvalExecutor::deduplicateTvSizesToBind() { - // Sort by tv pointer, fusion_input_pos (ascending), then logical_dim_pos - // (ascending) - std::sort( - tv_sizes_to_bind.begin(), - tv_sizes_to_bind.end(), - [](const TVInfo& a, const TVInfo& b) { - if (a.tv != b.tv) { - return a.tv < b.tv; - } - - if (a.fusion_input_pos != b.fusion_input_pos) { - return a.fusion_input_pos < b.fusion_input_pos; - } - - return a.logical_dim_pos < b.logical_dim_pos; - }); - - // remove consecutive duplicates - auto last = std::unique( - tv_sizes_to_bind.begin(), - tv_sizes_to_bind.end(), - [](const TVInfo& a, const TVInfo& b) { - return a.tv == b.tv && a.fusion_input_pos == b.fusion_input_pos && - a.logical_dim_pos == b.logical_dim_pos; - }); - tv_sizes_to_bind.erase(last, tv_sizes_to_bind.end()); -} - void ExprEvalExecutor::compile(Fusion* fusion) { FUSER_PERF_SCOPE("ExprEvalExecutor::compile"); if (isProfilerEnabled()) { @@ -91,6 +74,7 @@ void ExprEvalExecutor::compile(Fusion* fusion) { supported(fusion), "ExprEvalExecutor does not support the Fusion provided."); fusion_ = std::make_unique(*fusion); + auto extent_simplification_map = getSimplificationMap(fusion_.get()); auto mutation_map = ir_utils::replaceValue(fusion_.get(), extent_simplification_map); @@ -102,8 +86,13 @@ void ExprEvalExecutor::compile(Fusion* fusion) { for (auto id_i : c10::irange(domain.size())) { auto extent = domain[id_i]->getMaybeExpandedExtent(); extent_to_tv_info[extent] = {tv, inp_id, id_i}; + all_potential_input_scalars.pushBack( + domain[id_i]->getMaybeExpandedExtent()); } } + if (fusion_->inputs()[inp_id]->isIntegralScalar()) { + all_potential_input_scalars.pushBack(fusion_->inputs()[inp_id]); + } } exprs_ = fusion_->exprs(); @@ -114,15 +103,18 @@ void ExprEvalExecutor::compile(Fusion* fusion) { compile(expr->as()); } // TODO: support RepeatOp and GetMetaData - + NVF_ERROR( + !expr->isA() && !expr->isA(), + "Repeat op and 
MetaDataOp not implemented yet, found: ", + expr->toString()); for (auto expr_inp : expr->inputs()) { if (expr_inp->isIntegralScalar()) { - findAndBindInputTVExtentFrom(expr_inp); + needed_integer_scalars.pushBack(expr_inp); } } } - deduplicateTvSizesToBind(); + findAndBindInputTVExtentsFrom(needed_integer_scalars); if (isProfilerEnabled()) { FusionProfiler::segment(group_id_).stopCompile(); @@ -206,7 +198,6 @@ std::vector ExprEvalExecutor::run( " but only found ", logical_domain.size(), " dimensions."); - expr_eval.unsafeBind( logical_domain[tv_info.logical_dim_pos]->getMaybeExpandedExtent(), (*args[tv_info.fusion_input_pos]) @@ -297,7 +288,7 @@ void ExprEvalExecutor::compile(ViewOp* view_op) { // infer them. if (missing_vals > 1) { for (auto size : sizes) { - findAndBindInputTVExtentFrom(size); + needed_integer_scalars.pushBack(size); } } diff --git a/csrc/runtime/expr_eval_exec.h b/csrc/runtime/expr_eval_exec.h index 9e19a6dc132..c93f6d61511 100644 --- a/csrc/runtime/expr_eval_exec.h +++ b/csrc/runtime/expr_eval_exec.h @@ -7,6 +7,7 @@ // clang-format on #pragma once +#include #include #include @@ -68,21 +69,46 @@ class ExprEvalExecutor : public ExecutorAbstract { TensorView* tv; uint64_t fusion_input_pos; uint64_t logical_dim_pos; + + bool operator==(const TVInfo& other) const { + return tv == other.tv && fusion_input_pos == other.fusion_input_pos && + logical_dim_pos == other.logical_dim_pos; + } + }; + + // For use with VectorOfUniqueEntries + struct TVInfoHash { + std::size_t operator()(const TVInfo& info) const { + std::size_t hash = 0; + hash ^= std::hash()(info.tv); + hash ^= std::hash()(info.fusion_input_pos); + hash ^= std::hash()(info.logical_dim_pos) << 8; + return hash; + } }; // Expr eval exec only shallowly binds inputs. This means all sizes of each // tensor are not bound. During compilation information about which size // information needs to be pulled and bound are tracked. References entries in // extent_to_tv_info map. 
- std::vector tv_sizes_to_bind; + VectorOfUniqueEntries tv_sizes_to_bind; std::unordered_map extent_to_tv_info; + // Since input tensor views could be from an intermediate segmentation their + // logical domains could be a function of iter domains of a previous fusions. + // This means an input tensor could have an iter domain for example: + // iS24{( ceilDiv(( i0 * i2 ), 3) )} where i0 and i2 are not "inputs" to + // the fusion. This means we want to bind a the size of the input tensor to + // the entire scalar, not to i0 and i2. This unordered set will contain all + // input scalars and all logical domain scalars of input tensors, to resolve + // how to infer all necessary scalars for the fusion. + VectorOfUniqueEntries all_potential_input_scalars; + + // The scalars that need to be infered during execution. + VectorOfUniqueEntries needed_integer_scalars; // Goes to val's inputs and check if it's from a TensorView, if so it fills // tv_sizes_to_bind for those inputs. - void findAndBindInputTVExtentFrom(Val* val); - - // deduplicate entries in tv_sizes_to_bind - void deduplicateTvSizesToBind(); + void findAndBindInputTVExtentsFrom(VectorOfUniqueEntries vals); void compile(ViewOp* view_op); at::Tensor run(ViewOp* view_op, ExpressionEvaluator& expr_eval); From 9b56257e5a6c5d0a7711c0f7f3007b34936900bd Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Mon, 3 Feb 2025 10:29:31 -0800 Subject: [PATCH 12/16] Add tests and instrument evaluate --- csrc/ir/evaluate.cpp | 37 ++++++ tests/python/llama_inf_tests/graph_0.py | 78 ++++++++++++ tests/python/llama_inf_tests/graph_1.py | 154 ++++++++++++++++++++++++ tests/python/llama_inf_tests/graph_2.py | 105 ++++++++++++++++ 4 files changed, 374 insertions(+) create mode 100644 tests/python/llama_inf_tests/graph_0.py create mode 100644 tests/python/llama_inf_tests/graph_1.py create mode 100644 tests/python/llama_inf_tests/graph_2.py diff --git a/csrc/ir/evaluate.cpp b/csrc/ir/evaluate.cpp index 2ef2b4a9957..ed48c58c8ca 
100644 --- a/csrc/ir/evaluate.cpp +++ b/csrc/ir/evaluate.cpp @@ -17,6 +17,7 @@ namespace nvfuser { PolymorphicValue Val::evaluate() { if (this->value().hasValue()) { return this->value(); + FUSER_PERF_SCOPE("Val::evaluate"); } ExpressionEvaluator ee; @@ -31,6 +32,7 @@ PolymorphicValue Val::evaluate() { std::vector Expr::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("Expr::evaluate"); NVF_THROW( "`evaluate` method for expression ", getOpString(), @@ -41,6 +43,7 @@ std::vector Expr::evaluate( std::vector Expr::evaluate( const ExpressionEvaluator& ee, std::unordered_map& known_values) const { + FUSER_PERF_SCOPE("Expr::evaluate"); std::vector expr_inputs; expr_inputs.reserve(inputs().size()); for (auto inp : inputs()) { @@ -59,6 +62,7 @@ void Expr::addDataAttribute(PolymorphicValue attr) { std::vector FullOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("FullOp::evaluate"); std::vector shape; for (auto i : c10::irange(inputs.size() - 1)) { shape.push_back(inputs.at(i).as()); @@ -73,6 +77,7 @@ std::vector FullOp::evaluate( std::vector SelectOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("SelectOp::evaluate"); const auto& in = inputs.at(0).as(); int64_t dimension = dim(); int64_t index = (int64_t)inputs.at(1); @@ -82,6 +87,7 @@ std::vector SelectOp::evaluate( std::vector IndexSelectOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("IndexSelectOp::evaluate"); const auto& in = inputs.at(0).as(); int64_t dimension = dim(); const auto& indices = inputs.at(1).as().squeeze(); @@ -91,6 +97,7 @@ std::vector IndexSelectOp::evaluate( std::vector TorchGatherOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("TorchGatherOp::evaluate"); const auto& input = inputs.at(0).as(); const auto& index = inputs.at(1).as(); auto dimension = dim(); @@ 
-104,6 +111,7 @@ std::vector TorchGatherOp::evaluate( std::vector ScatterOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("ScatterOp::evaluate"); const auto& input = inputs.at(0).as(); const auto& index = inputs.at(1).as(); const auto& src = inputs.at(2).as(); @@ -114,6 +122,7 @@ std::vector ScatterOp::evaluate( std::vector IotaOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("IotaOp::evaluate"); const auto options = at::TensorOptions().device(at::kCUDA).dtype(data_type_to_aten(dtype())); int64_t length = (int64_t)inputs.at(0); @@ -139,6 +148,7 @@ std::vector IotaOp::evaluate( std::vector EyeOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("EyeOp::evaluate"); const auto options = at::TensorOptions().device(at::kCUDA).dtype(data_type_to_aten(dtype())); int64_t nrows = (int64_t)inputs.at(0); @@ -153,6 +163,7 @@ std::vector EyeOp::evaluate( std::vector UnaryOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("UnaryOp::evaluate"); using namespace PolymorphicValue_functions; const auto& in = inputs.at(0); @@ -280,6 +291,7 @@ std::vector UnaryOp::evaluate( std::vector BinaryOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("BinaryOp::evaluate"); using namespace PolymorphicValue_functions; const auto& lhs = inputs.at(0); const auto& rhs = inputs.at(1); @@ -376,6 +388,7 @@ std::vector BinaryOp::evaluate( std::vector TernaryOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("TernaryOp::evaluate"); using namespace PolymorphicValue_functions; const auto& a = inputs.at(0); const auto& b = inputs.at(1); @@ -408,12 +421,14 @@ std::vector TernaryOp::evaluate( std::vector ArrayConstruct::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + 
FUSER_PERF_SCOPE("ArrayConstruct::evaluate"); return {PolymorphicValue(inputs)}; } std::vector ReverseArray::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("ReverseArray::evaluate"); NVF_ERROR(inputs.size() == 1, "ReverseArray expects 1 input"); PolymorphicValue array = inputs.at(0); auto& vec = array.as(); @@ -424,6 +439,7 @@ std::vector ReverseArray::evaluate( std::vector GetItem::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("GetItem::evaluate"); NVF_ERROR(inputs.size() == 2, "GetItem expects 2 inputs"); return {PolymorphicValue(inputs.at(0)[inputs.at(1)])}; } @@ -431,6 +447,7 @@ std::vector GetItem::evaluate( std::vector StructConstruct::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("StructConstruct::evaluate"); NVF_ERROR( this->inputs().size() == inputs.size(), "StructConstruct expects ", @@ -447,6 +464,7 @@ std::vector StructConstruct::evaluate( std::vector GetAttr::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("GetAttr::evaluate"); NVF_ERROR(inputs.size() == 1, "GetAttr expects 1 input"); return {inputs.at(0)->*attr()}; } @@ -454,6 +472,7 @@ std::vector GetAttr::evaluate( std::vector TensorConstruct::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("TensorConstruct::evaluate"); NVF_ERROR(inputs.size() == 1, "TensorConstruct expects 1 input"); using namespace PolymorphicValue_functions; return {toTensor(inputs.at(0))}; @@ -462,6 +481,7 @@ std::vector TensorConstruct::evaluate( std::vector BroadcastOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("BroadcastOp::evaluate"); NVF_ERROR( inputs.size() == 1, "BroadcastOp expects exactly 1 input, but received ", @@ -482,6 +502,7 @@ std::vector BroadcastOp::evaluate( std::vector SqueezeOp::evaluate( const ExpressionEvaluator& ee, 
const std::vector& inputs) const { + FUSER_PERF_SCOPE("SqueezeOp::evaluate"); NVF_ERROR( inputs.size() == 1, "SqueezeOp expects exactly 1 input, but received ", @@ -511,6 +532,7 @@ std::vector SqueezeOp::evaluate( std::vector ReductionOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("ReductionOp::evaluate"); const auto& input = inputs.at(0).as(); const auto output = out()->as(); @@ -548,6 +570,7 @@ std::vector ReductionOp::evaluate( std::vector GroupedReductionOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("GroupedReductionOp::evaluate"); const auto num_reductions = numHorizontallyGroupedExprs(); std::vector grouped_reduction_out; grouped_reduction_out.reserve(num_reductions); @@ -588,6 +611,7 @@ std::vector GroupedReductionOp::evaluate( std::vector WelfordOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("WelfordOp::evaluate"); NVF_ERROR( !hasInit(), "Evaluation for WelfordOp is not implemented for non-empty initial values."); @@ -614,6 +638,7 @@ std::vector WelfordOp::evaluate( std::vector ExpandOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("ExpandOp::evaluate"); const auto& in = inputs.at(0).as(); std::vector expanded_size; for (auto i : c10::irange(1, inputs.size())) { @@ -625,6 +650,7 @@ std::vector ExpandOp::evaluate( std::vector RepeatOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("RepeatOp::evaluate"); NVF_ERROR( inputs.size() == 1, "RepeatOp expects exactly 1 input, but received ", @@ -656,6 +682,7 @@ std::vector RepeatOp::evaluate( std::vector ViewAsScalar::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("ViewAsScalar::evaluate"); const at::Tensor& in = inputs.at(0).as(); return {at::view_as_real(in)}; } @@ -664,6 +691,7 @@ std::vector ViewOp::evaluate( 
const ExpressionEvaluator& ee, const std::vector& inputs) const { FUSER_PERF_SCOPE("ViewOp::evaluate"); + NVF_ERROR(inputs.size() == 1); const at::Tensor& in_tensor = inputs[0].as(); @@ -709,6 +737,7 @@ std::vector LoadStoreOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { FUSER_PERF_SCOPE("LoadStoreOp::evaluate"); + if (TensorView* out_tv = dynamic_cast(out())) { if (out_tv->hasRoot()) { std::optional> permutation = @@ -730,6 +759,7 @@ std::vector LoadStoreOp::evaluate( std::vector PadOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("PadOp::evaluate"); const auto& in = inputs.at(0).as(); std::vector pad_widths; @@ -760,6 +790,7 @@ std::vector PadOp::evaluate( std::vector SliceOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("SliceOp::evaluate"); const auto& in = inputs.at(0).as(); std::vector ranges; auto ranges_offset = getRangeInputOffset(); @@ -776,6 +807,7 @@ std::vector SliceOp::evaluate( std::vector CatOp::evaluate( const ExpressionEvaluator& ee, std::unordered_map& known_values) const { + FUSER_PERF_SCOPE("CatOp::evaluate"); // CatOp is preceded by a PadOp internally. // For ATen evaluation, directly compute the unpadded inputs. 
std::vector unpadded_inputs; @@ -794,6 +826,7 @@ std::vector CatOp::evaluate( std::vector MatmulOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("MatmulOp::evaluate"); const auto a = inputs.at(0).as(); const auto b = inputs.at(1).as(); @@ -839,6 +872,7 @@ std::vector MatmulOp::evaluate( std::vector LinearOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("LinearOp::evaluate"); const auto in = inputs.at(0).as(); auto weight = inputs.at(1).as(); @@ -880,6 +914,7 @@ std::vector LinearOp::evaluate( std::vector SdpaFwdOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("SdpaFwdOp::evaluate"); auto query = inputs.at(0).as(); auto key = inputs.at(1).as(); auto value = inputs.at(2).as(); @@ -980,6 +1015,7 @@ std::vector SdpaFwdOp::evaluate( std::vector SdpaBwdOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("SdpaBwdOp::evaluate"); // Backward tensor inputs: grad_input, query, key, value, output, // logsumexp, max_q/k Temporary handling of DID parallelization. 
See // https://github.com/NVIDIA/Fuser/issues/2563 @@ -1068,6 +1104,7 @@ std::vector SdpaBwdOp::evaluate( std::vector EmbeddingFwdOp::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { + FUSER_PERF_SCOPE("EmbeddingFwdOp::evaluate"); auto input = inputs.at(0).as(); auto weight = inputs.at(1).as(); auto norm_type = inputs.at(2).as(); diff --git a/tests/python/llama_inf_tests/graph_0.py b/tests/python/llama_inf_tests/graph_0.py new file mode 100644 index 00000000000..928eb8ac2e5 --- /dev/null +++ b/tests/python/llama_inf_tests/graph_0.py @@ -0,0 +1,78 @@ +import torch +from nvfuser import FusionDefinition, DataType +import time + +def nvfuser_fusion_id0(fd : FusionDefinition) -> None : + T0 = fd.define_tensor(shape=[1, 6], contiguity=[None, True], dtype=DataType.Int, is_cpu=False, stride_order=[1, 0]) + T1 = fd.define_tensor(shape=[128256, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0]) + T2 = fd.define_tensor(shape=[1, 6], contiguity=[None, True], dtype=DataType.Int, is_cpu=False, stride_order=[1, 0]) + S3 = fd.define_scalar(2.00000, dtype=DataType.Double) + S4 = fd.define_scalar(False, dtype=DataType.Bool) + S5 = fd.define_scalar(False, dtype=DataType.Bool) + T6 = fd.ops.embedding_fwd(T0, T1, None, None, S3, S4, S5) + S7 = fd.define_scalar(6, dtype=DataType.Int) + S8 = fd.define_scalar(0, dtype=DataType.Int) + S9 = fd.define_scalar(1, dtype=DataType.Int) + T10 = fd.ops.iota(S7, S8, S9, dtype=DataType.Int) + T14 = fd.ops.broadcast_in_dim(T10, shape=[1, 6], broadcast_dims=[1]) + S15 = fd.define_scalar(-3.38953e+38, dtype=DataType.Double) + T19 = fd.ops.full(shape=[6, 6], fill_value=S15, dtype=DataType.BFloat16) + T23 = fd.ops.broadcast_in_dim(T10, shape=[6, 1], broadcast_dims=[0]) + T27 = fd.ops.broadcast_in_dim(T14, shape=[6, 6], broadcast_dims=[0, 1]) + T31 = fd.ops.broadcast_in_dim(T23, shape=[6, 6], broadcast_dims=[0, 1]) + T32 = fd.ops.sub(T27, T31) + S33 = fd.define_scalar(1, 
dtype=DataType.Int) + T34 = fd.ops.ge(T32, S33) + S35 = fd.define_scalar(0.00000, dtype=DataType.Double) + T36 = fd.ops.where(T34, T19, S35) + T40 = fd.ops.reshape(T10, new_shape=[6, 1]) + T44 = fd.ops.broadcast_in_dim(T10, shape=[6, 6], broadcast_dims=[1]) + T48 = fd.ops.broadcast_in_dim(T40, shape=[6, 6], broadcast_dims=[0, 1]) + T49 = fd.ops.gt(T44, T48) + T50 = fd.ops.cast(T36, dtype=DataType.Float) + T51 = fd.ops.cast(T49, dtype=DataType.Float) + T52 = fd.ops.mul(T50, T51) + T53 = fd.ops.cast(T52, dtype=DataType.BFloat16) + T59 = fd.ops.broadcast_in_dim(T53, shape=[1, 1, 6, 6], broadcast_dims=[2, 3]) + T65 = fd.ops.broadcast_in_dim(T59, shape=[1, 1, 6, 6], broadcast_dims=[0, 1, 2, 3]) + T66 = fd.ops.set(T65) + T72 = fd.ops.broadcast_in_dim(T2, shape=[1, 1, 1, 6], broadcast_dims=[0, 3]) + T78 = fd.ops.broadcast_in_dim(T72, shape=[1, 1, 6, 6], broadcast_dims=[0, 1, 2, 3]) + T79 = fd.ops.cast(T66, dtype=DataType.Float) + T80 = fd.ops.cast(T78, dtype=DataType.Float) + T81 = fd.ops.add(T79, T80) + T82 = fd.ops.cast(T81, dtype=DataType.BFloat16) + S83 = fd.define_scalar(0.00000, dtype=DataType.Double) + T84 = fd.ops.eq(T82, S83) + S85 = fd.define_scalar(-3.38953e+38, dtype=DataType.Double) + T86 = fd.ops.where(T84, S85, T66) + fd.add_output(T6) + fd.add_output(T66) + fd.add_output(T86) + +with FusionDefinition() as fd: + nvfuser_fusion_id0(fd) + +inputs = [ + torch.ones((1, 6), dtype=torch.int64, device='cuda:0'), + torch.testing.make_tensor((128256, 2048), dtype=torch.bfloat16, device='cuda:0'), + torch.ones((1, 6), dtype=torch.int64, device='cuda:0'), +] + +fd.execute(inputs) + +# for _ in range(3): +# fd.execute(inputs) + +# torch.cuda.synchronize() +# start = time.time() +# # Mark the profiling region +# torch.cuda.cudart().cudaProfilerStart() + +# for _ in range(100): +# fd.execute(inputs) + +# torch.cuda.cudart().cudaProfilerStop() +# torch.cuda.synchronize() +# end = time.time() +# print(end-start) \ No newline at end of file diff --git 
a/tests/python/llama_inf_tests/graph_1.py b/tests/python/llama_inf_tests/graph_1.py new file mode 100644 index 00000000000..dfa92968a47 --- /dev/null +++ b/tests/python/llama_inf_tests/graph_1.py @@ -0,0 +1,154 @@ +import torch +from nvfuser import FusionDefinition, DataType +import time + +def nvfuser_fusion_id1(fd : FusionDefinition) -> None : + T0 = fd.define_tensor(shape=[1, 6, 2048], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0]) + T1 = fd.define_tensor(shape=[2048], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0]) + T2 = fd.define_tensor(shape=[32], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0]) + T3 = fd.define_tensor(shape=[512, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0]) + T4 = fd.define_tensor(shape=[2048, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0]) + T5 = fd.define_tensor(shape=[1, 1, 6, 6], contiguity=[True, None, None, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 3, 0]) + T6 = fd.define_tensor(shape=[512, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0]) + T7 = fd.ops.cast(T0, dtype=DataType.Float) + S8 = fd.define_scalar(2.00000, dtype=DataType.Double) + T9 = fd.ops.pow(T7, S8) + T10 = fd.ops.sum(T9, dims=[2], keepdim=False, dtype=DataType.Null) + T15 = fd.ops.broadcast_in_dim(T10, shape=[1, 6, 1], broadcast_dims=[0, 1]) + S16 = fd.define_scalar(2048.00, dtype=DataType.Double) + S17 = fd.ops.reciprocal(S16) + T18 = fd.ops.mul(T15, S17) + S19 = fd.define_scalar(1.00000e-05, dtype=DataType.Double) + T20 = fd.ops.add(T18, S19) + T21 = fd.ops.rsqrt(T20) + T26 = fd.ops.broadcast_in_dim(T21, shape=[1, 6, 2048], broadcast_dims=[0, 1, 2]) + T27 = fd.ops.mul(T7, T26) + S28 = fd.define_scalar(6, dtype=DataType.Int) + S29 = fd.define_scalar(0, dtype=DataType.Int) + S30 = fd.define_scalar(1, 
dtype=DataType.Int) + T31 = fd.ops.iota(S28, S29, S30, dtype=DataType.Int) + T36 = fd.ops.broadcast_in_dim(T1, shape=[1, 6, 2048], broadcast_dims=[2]) + T40 = fd.ops.broadcast_in_dim(T31, shape=[1, 6], broadcast_dims=[1]) + T45 = fd.ops.broadcast_in_dim(T2, shape=[1, 32, 1], broadcast_dims=[1]) + T46 = fd.ops.cast(T36, dtype=DataType.Float) + T51 = fd.ops.broadcast_in_dim(T40, shape=[1, 1, 6], broadcast_dims=[0, 2]) + T52 = fd.ops.cast(T45, dtype=DataType.Float) + T53 = fd.ops.mul(T46, T27) + T54 = fd.ops.cast(T51, dtype=DataType.Float) + T59 = fd.ops.broadcast_in_dim(T52, shape=[1, 32, 1], broadcast_dims=[0, 1, 2]) + T60 = fd.ops.cast(T53, dtype=DataType.BFloat16) + T61 = fd.ops.matmul(T59, T54) + T62 = fd.ops.linear(T60, T3) + T63 = fd.ops.permute(T61, dims=[0, 2, 1]) + T69 = fd.ops.reshape(T62, new_shape=[1, 6, 8, 64]) + T70 = fd.ops.cat([T63, T63], dim=-1, manual_padding=0) + T71 = fd.ops.permute(T69, dims=[0, 2, 1, 3]) + T72 = fd.ops.sin(T70) + T88 = fd.ops.slice(T71, start_indices=[0, 0, 0, 32], end_indices=[1, 8, 6, 64], strides=[1, 1, 1, 1], manual_normalization=0) + T89 = fd.ops.cos(T70) + T90 = fd.ops.linear(T60, T4) + S91 = fd.define_scalar(1.00000, dtype=DataType.Double) + T92 = fd.ops.mul(T72, S91) + T93 = fd.ops.cast(T88, dtype=DataType.Float) + S94 = fd.define_scalar(1.00000, dtype=DataType.Double) + T95 = fd.ops.mul(T89, S94) + T101 = fd.ops.reshape(T90, new_shape=[1, 6, 32, 64]) + T102 = fd.ops.cast(T92, dtype=DataType.BFloat16) + T103 = fd.ops.neg(T93) + T104 = fd.ops.cast(T95, dtype=DataType.BFloat16) + T105 = fd.ops.permute(T101, dims=[0, 2, 1, 3]) + T111 = fd.ops.broadcast_in_dim(T102, shape=[1, 1, 6, 64], broadcast_dims=[0, 2, 3]) + T127 = fd.ops.slice(T71, start_indices=[0, 0, 0, 0], end_indices=[1, 8, 6, 32], strides=[1, 1, 1, 1], manual_normalization=0) + T128 = fd.ops.cast(T103, dtype=DataType.BFloat16) + T134 = fd.ops.broadcast_in_dim(T104, shape=[1, 1, 6, 64], broadcast_dims=[0, 2, 3]) + T150 = fd.ops.slice(T105, start_indices=[0, 0, 0, 
32], end_indices=[1, 32, 6, 64], strides=[1, 1, 1, 1], manual_normalization=0) + S151 = fd.define_scalar(-3.38953e+38, dtype=DataType.Double) + T152 = fd.ops.eq(T5, S151) + T158 = fd.ops.broadcast_in_dim(T111, shape=[1, 8, 6, 64], broadcast_dims=[0, 1, 2, 3]) + T159 = fd.ops.cat([T128, T127], dim=-1, manual_padding=0) + T165 = fd.ops.broadcast_in_dim(T134, shape=[1, 8, 6, 64], broadcast_dims=[0, 1, 2, 3]) + T166 = fd.ops.cast(T150, dtype=DataType.Float) + T167 = fd.ops.bitwise_not(T152) + T168 = fd.ops.cast(T158, dtype=DataType.Float) + T169 = fd.ops.cast(T159, dtype=DataType.Float) + T170 = fd.ops.cast(T165, dtype=DataType.Float) + T171 = fd.ops.cast(T71, dtype=DataType.Float) + T172 = fd.ops.neg(T166) + T173 = fd.ops.cast(T167, dtype=DataType.Int) + T174 = fd.ops.mul(T169, T168) + T175 = fd.ops.mul(T171, T170) + T191 = fd.ops.slice(T105, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 6, 32], strides=[1, 1, 1, 1], manual_normalization=0) + T192 = fd.ops.cast(T172, dtype=DataType.BFloat16) + T193 = fd.ops.sum(T173, dims=[3], keepdim=False, dtype=DataType.Null) + T199 = fd.ops.broadcast_in_dim(T111, shape=[1, 32, 6, 64], broadcast_dims=[0, 1, 2, 3]) + T200 = fd.ops.cat([T192, T191], dim=-1, manual_padding=0) + T206 = fd.ops.broadcast_in_dim(T134, shape=[1, 32, 6, 64], broadcast_dims=[0, 1, 2, 3]) + T212 = fd.ops.broadcast_in_dim(T193, shape=[1, 1, 6, 1], broadcast_dims=[0, 1, 2]) + T213 = fd.ops.linear(T60, T6) + T214 = fd.ops.cast(T199, dtype=DataType.Float) + T215 = fd.ops.cast(T200, dtype=DataType.Float) + T216 = fd.ops.cast(T206, dtype=DataType.Float) + T217 = fd.ops.cast(T105, dtype=DataType.Float) + S218 = fd.define_scalar(0, dtype=DataType.Int) + T219 = fd.ops.ne(T212, S218) + T225 = fd.ops.reshape(T213, new_shape=[1, 6, 8, 64]) + T226 = fd.ops.add(T175, T174) + T227 = fd.ops.mul(T215, T214) + T228 = fd.ops.mul(T217, T216) + T229 = fd.ops.bitwise_not(T219) + T230 = fd.ops.permute(T225, dims=[0, 2, 1, 3]) + T231 = fd.ops.cast(T226, dtype=DataType.BFloat16) + 
T232 = fd.ops.bitwise_not(T229) + T239 = fd.ops.broadcast_in_dim(T230, shape=[1, 8, 1, 6, 64], broadcast_dims=[0, 1, 3, 4]) + T246 = fd.ops.broadcast_in_dim(T231, shape=[1, 8, 1, 6, 64], broadcast_dims=[0, 1, 3, 4]) + T252 = fd.ops.broadcast_in_dim(T232, shape=[1, 1, 6, 6], broadcast_dims=[0, 1, 2, 3]) + T259 = fd.ops.broadcast_in_dim(T239, shape=[1, 8, 4, 6, 64], broadcast_dims=[0, 1, 2, 3, 4]) + T266 = fd.ops.broadcast_in_dim(T246, shape=[1, 8, 4, 6, 64], broadcast_dims=[0, 1, 2, 3, 4]) + T267 = fd.ops.add(T228, T227) + T268 = fd.ops.cast(T252, dtype=DataType.Float) + T269 = fd.ops.cast(T5, dtype=DataType.Float) + T275 = fd.ops.reshape(T259, new_shape=[1, 32, 6, 64]) + T281 = fd.ops.reshape(T266, new_shape=[1, 32, 6, 64]) + T282 = fd.ops.cast(T267, dtype=DataType.BFloat16) + T283 = fd.ops.mul(T269, T268) + T284 = fd.ops.stride_order(T275, stride_order=[3, 2, 1, 0]) + T285 = fd.ops.stride_order(T281, stride_order=[3, 2, 1, 0]) + T286 = fd.ops.stride_order(T282, stride_order=[3, 2, 1, 0]) + T287 = fd.ops.cast(T283, dtype=DataType.BFloat16) + fd.add_output(T287) + fd.add_output(T230) + fd.add_output(T231) + fd.add_output(T286) + fd.add_output(T285) + fd.add_output(T284) + +with FusionDefinition() as fd: + nvfuser_fusion_id1(fd) + +inputs = [ + torch.testing.make_tensor((1, 6, 2048), dtype=torch.bfloat16, device='cuda:0'), + torch.testing.make_tensor((2048,), dtype=torch.bfloat16, device='cuda:0'), + torch.testing.make_tensor((32,), dtype=torch.bfloat16, device='cuda:0'), + torch.testing.make_tensor((512, 2048), dtype=torch.bfloat16, device='cuda:0'), + torch.testing.make_tensor((2048, 2048), dtype=torch.bfloat16, device='cuda:0'), + torch.testing.make_tensor((1, 1, 6, 6), dtype=torch.bfloat16, device='cuda:0'), + torch.testing.make_tensor((512, 2048), dtype=torch.bfloat16, device='cuda:0'), +] + +fd.execute(inputs) + +# for _ in range(3): +# fd.execute(inputs) + +# torch.cuda.synchronize() +# start = time.time() +# # Mark the profiling region +# 
torch.cuda.cudart().cudaProfilerStart() + +# for _ in range(100): +# fd.execute(inputs) + +# torch.cuda.cudart().cudaProfilerStop() +# torch.cuda.synchronize() +# end = time.time() +# print(end-start) \ No newline at end of file diff --git a/tests/python/llama_inf_tests/graph_2.py b/tests/python/llama_inf_tests/graph_2.py new file mode 100644 index 00000000000..753115abd61 --- /dev/null +++ b/tests/python/llama_inf_tests/graph_2.py @@ -0,0 +1,105 @@ +import torch +from nvfuser import FusionDefinition, DataType +import time + +def nvfuser_fusion_id2(fd : FusionDefinition) -> None : + T0 = fd.define_tensor(shape=[1, 32, 6, 64], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 1, 2, 0]) + T1 = fd.define_tensor(shape=[2048, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0]) + T2 = fd.define_tensor(shape=[1, 6, 2048], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0]) + T3 = fd.define_tensor(shape=[2048], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0]) + T4 = fd.define_tensor(shape=[8192, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0]) + T5 = fd.define_tensor(shape=[8192, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0]) + T6 = fd.define_tensor(shape=[2048, 8192], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0]) + T7 = fd.define_tensor(shape=[2048], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0]) + T8 = fd.define_tensor(shape=[128256, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0]) + T9 = fd.ops.permute(T0, dims=[0, 2, 1, 3]) + T10 = fd.ops.stride_order(T9, stride_order=[3, 2, 1, 0]) + T15 = fd.ops.reshape(T10, new_shape=[1, 6, 2048]) + T16 = fd.ops.stride_order(T15, stride_order=[2, 1, 0]) + T17 = fd.ops.linear(T16, T1) + T18 
= fd.ops.cast(T2, dtype=DataType.Float) + T19 = fd.ops.cast(T17, dtype=DataType.Float) + T20 = fd.ops.add(T18, T19) + S21 = fd.define_scalar(2.00000, dtype=DataType.Double) + T22 = fd.ops.pow(T20, S21) + T23 = fd.ops.sum(T22, dims=[2], keepdim=False, dtype=DataType.Null) + T28 = fd.ops.broadcast_in_dim(T23, shape=[1, 6, 1], broadcast_dims=[0, 1]) + S29 = fd.define_scalar(2048.00, dtype=DataType.Double) + S30 = fd.ops.reciprocal(S29) + T31 = fd.ops.mul(T28, S30) + S32 = fd.define_scalar(1.00000e-05, dtype=DataType.Double) + T33 = fd.ops.add(T31, S32) + T34 = fd.ops.rsqrt(T33) + T39 = fd.ops.broadcast_in_dim(T34, shape=[1, 6, 2048], broadcast_dims=[0, 1, 2]) + T40 = fd.ops.mul(T20, T39) + T45 = fd.ops.broadcast_in_dim(T3, shape=[1, 6, 2048], broadcast_dims=[2]) + T46 = fd.ops.cast(T45, dtype=DataType.Float) + T47 = fd.ops.mul(T46, T40) + T48 = fd.ops.cast(T47, dtype=DataType.BFloat16) + T49 = fd.ops.linear(T48, T4) + T50 = fd.ops.cast(T49, dtype=DataType.Float) + T51 = fd.ops.neg(T50) + T52 = fd.ops.exp(T51) + S53 = fd.define_scalar(1.00000, dtype=DataType.Double) + T54 = fd.ops.add(S53, T52) + T55 = fd.ops.reciprocal(T54) + T56 = fd.ops.mul(T50, T55) + T57 = fd.ops.linear(T48, T5) + T58 = fd.ops.cast(T57, dtype=DataType.Float) + T59 = fd.ops.mul(T56, T58) + T60 = fd.ops.cast(T59, dtype=DataType.BFloat16) + T61 = fd.ops.linear(T60, T6) + T62 = fd.ops.cast(T61, dtype=DataType.Float) + T63 = fd.ops.add(T20, T62) + S64 = fd.define_scalar(2.00000, dtype=DataType.Double) + T65 = fd.ops.pow(T63, S64) + T66 = fd.ops.sum(T65, dims=[2], keepdim=False, dtype=DataType.Null) + T71 = fd.ops.broadcast_in_dim(T66, shape=[1, 6, 1], broadcast_dims=[0, 1]) + S72 = fd.define_scalar(2048.00, dtype=DataType.Double) + S73 = fd.ops.reciprocal(S72) + T74 = fd.ops.mul(T71, S73) + S75 = fd.define_scalar(1.00000e-05, dtype=DataType.Double) + T76 = fd.ops.add(T74, S75) + T77 = fd.ops.rsqrt(T76) + T82 = fd.ops.broadcast_in_dim(T77, shape=[1, 6, 2048], broadcast_dims=[0, 1, 2]) + T83 = 
fd.ops.mul(T63, T82) + T88 = fd.ops.broadcast_in_dim(T7, shape=[1, 6, 2048], broadcast_dims=[2]) + T89 = fd.ops.cast(T88, dtype=DataType.Float) + T90 = fd.ops.mul(T89, T83) + T91 = fd.ops.cast(T90, dtype=DataType.BFloat16) + T92 = fd.ops.linear(T91, T8) + fd.add_output(T92) + +with FusionDefinition() as fd: + nvfuser_fusion_id2(fd) + +inputs = [ + torch.randn(12288, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 32, 6, 64), (12288, 64, 2048, 1)), + torch.testing.make_tensor((2048, 2048), dtype=torch.bfloat16, device='cuda:0'), + torch.testing.make_tensor((1, 6, 2048), dtype=torch.bfloat16, device='cuda:0'), + torch.testing.make_tensor((2048,), dtype=torch.bfloat16, device='cuda:0'), + torch.testing.make_tensor((8192, 2048), dtype=torch.bfloat16, device='cuda:0'), + torch.testing.make_tensor((8192, 2048), dtype=torch.bfloat16, device='cuda:0'), + torch.testing.make_tensor((2048, 8192), dtype=torch.bfloat16, device='cuda:0'), + torch.testing.make_tensor((2048,), dtype=torch.bfloat16, device='cuda:0'), + torch.testing.make_tensor((128256, 2048), dtype=torch.bfloat16, device='cuda:0'), +] + +fd.execute(inputs) + + +# for _ in range(3): +# fd.execute(inputs) + +# torch.cuda.synchronize() +# start = time.time() +# # Mark the profiling region +# torch.cuda.cudart().cudaProfilerStart() + +# for _ in range(100): +# fd.execute(inputs) + +# torch.cuda.cudart().cudaProfilerStop() +# torch.cuda.synchronize() +# end = time.time() +# print(end-start) \ No newline at end of file From a38ce70af5428d8963d5a0c5dbebff8beb1f5449 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Tue, 4 Feb 2025 10:38:40 -0800 Subject: [PATCH 13/16] Extraneous start profiling call. 
--- csrc/runtime/expr_eval_exec.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/runtime/expr_eval_exec.cpp b/csrc/runtime/expr_eval_exec.cpp index 6be16010784..6c484d753c2 100644 --- a/csrc/runtime/expr_eval_exec.cpp +++ b/csrc/runtime/expr_eval_exec.cpp @@ -119,7 +119,6 @@ void ExprEvalExecutor::compile(Fusion* fusion) { if (isProfilerEnabled()) { FusionProfiler::segment(group_id_).stopCompile(); } - cudaProfilerStart(); } bool ExprEvalExecutor::isCompiled() const { From 73c0deebc4b2a2636b56800bffef52ffa847b85a Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 2 Mar 2025 13:45:53 -0800 Subject: [PATCH 14/16] Update with executor_cleanup results. --- tests/python/llama_inf_tests/graph_0.py | 77 ++++++++++++++++++++---- tests/python/llama_inf_tests/graph_1.py | 78 +++++++++++++++++++++---- tests/python/llama_inf_tests/graph_2.py | 77 ++++++++++++++++++++---- 3 files changed, 196 insertions(+), 36 deletions(-) diff --git a/tests/python/llama_inf_tests/graph_0.py b/tests/python/llama_inf_tests/graph_0.py index 928eb8ac2e5..0bfc9db0d00 100644 --- a/tests/python/llama_inf_tests/graph_0.py +++ b/tests/python/llama_inf_tests/graph_0.py @@ -61,18 +61,71 @@ def nvfuser_fusion_id0(fd : FusionDefinition) -> None : fd.execute(inputs) -# for _ in range(3): -# fd.execute(inputs) +for _ in range(3): + fd.execute(inputs) -# torch.cuda.synchronize() -# start = time.time() -# # Mark the profiling region -# torch.cuda.cudart().cudaProfilerStart() +torch.cuda.synchronize() +start = time.time() +# Mark the profiling region +torch.cuda.cudart().cudaProfilerStart() -# for _ in range(100): -# fd.execute(inputs) +for _ in range(100): + fd.execute(inputs) -# torch.cuda.cudart().cudaProfilerStop() -# torch.cuda.synchronize() -# end = time.time() -# print(end-start) \ No newline at end of file +torch.cuda.cudart().cudaProfilerStop() +torch.cuda.synchronize() +end = time.time() + +print((end-start)*1000, " ms") + +# Before: +# 12.0 ms +# After: +# 3.1 ms + +# rm report* +# 
nsys profile -c cudaProfilerApi python tests/python/llama_inf_tests/graph_0.py +# nsys stats report1.nsys-rep + +# Before: +# Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Style Range +# -------- --------------- --------- -------- -------- -------- -------- ----------- ------- ---------------------------------------------- +# 13.8 10011392 100 100113.9 80400.0 76319 768432 82944.7 PushPop :FusionExecutorCache::runFusionWithInputs +# 13.0 9409367 100 94093.7 77940.0 74188 765647 79435.0 PushPop :FusionKernelRuntime::runWithInputs +# 12.9 9353511 100 93535.1 77347.0 73635 764599 79335.4 PushPop :FusionKernelRuntime::runSegmentsWithInputs +# 12.4 8989375 300 29964.6 26494.0 12397 698157 44537.4 PushPop :FusionKernelRuntime::runKernelWithInput +# 12.3 8896373 300 29654.6 26056.5 12135 697796 44508.6 PushPop :ExecutorDispatch::run2 +# 10.1 7309840 200 36549.2 31871.0 24321 697376 51775.8 PushPop :KernelExecutor::runFusion +# 6.7 4859672 200 24298.4 22950.5 13246 684391 48960.1 PushPop :KernelExecutor::runFusion::execute_kernel +# 5.9 4316308 1200 3596.9 2396.0 1980 175457 7203.4 PushPop :ExpressionEvaluator::evaluate +# 5.6 4086264 200 20431.3 19635.5 10005 674394 48476.2 PushPop :KernelExecutor::recomputeArgs +# 2.0 1455689 100 14556.9 12596.0 11930 176349 16430.4 PushPop :ExprEvalExecutor::run +# 1.9 1362206 200 6811.0 7320.5 3864 174236 12085.8 PushPop :fusion_executor::allocations::allocateOutputs +# 1.4 997365 600 1662.3 1368.0 1205 167712 6793.6 PushPop :fusion_executor::allocations::allocateTensor +# 1.0 690717 200 3453.6 3288.0 2816 10209 890.9 PushPop :ExecutorRunFusion::cuLaunchKernel +# 0.4 294173 300 980.6 831.0 107 9627 925.9 PushPop :executor_utils::bindInputs +# 0.3 228065 300 760.2 152.5 122 165328 9538.9 PushPop :ExecutorDispatch::isCompiled +# 0.3 192836 200 964.2 108.0 99 164903 11650.7 PushPop :KernelExecutor::runFusion::intermediates +# 0.1 75680 100 756.8 712.0 459 4308 388.0 PushPop 
:FusionExecutorCache::setCacheId +# 0.0 17393 100 173.9 133.0 112 875 100.7 PushPop :FusionExecutorCache::getKernelRuntimeFor + +# After: +# Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Style Range +# -------- --------------- --------- -------- -------- -------- -------- ----------- ------- ---------------------------------------------- +# 17.1 5182038 100 51820.4 40488.5 38316 309012 45433.7 PushPop :FusionExecutorCache::runFusionWithInputs +# 15.5 4712599 100 47126.0 38026.5 36027 293111 40961.2 PushPop :FusionKernelRuntime::runWithInputs +# 15.3 4653120 100 46531.2 37485.5 35536 290957 40853.9 PushPop :FusionKernelRuntime::runSegmentsWithInputs +# 13.5 4099585 300 13665.3 11896.0 8647 197301 18602.1 PushPop :FusionKernelRuntime::runKernelWithInput +# 12.5 3810167 300 12700.6 11606.0 8426 196668 15305.2 PushPop :ExecutorDispatch::run2 +# 7.0 2114207 200 10571.0 10721.5 8123 45738 3065.6 PushPop :KernelExecutor::runFusion +# 5.3 1601371 100 16013.7 11992.5 11374 196441 25772.1 PushPop :ExprEvalExecutor::run +# 4.6 1406004 100 14060.0 10303.5 9756 194564 24892.8 PushPop :ExpressionEvaluator::evaluate +# 2.6 803300 200 4016.5 4956.5 2605 16355 1490.5 PushPop :fusion_executor::allocations::allocateOutputs +# 2.4 722700 200 3613.5 3430.5 2910 15725 1130.3 PushPop :KernelExecutor::runFusion::execute_kernel +# 2.2 666998 200 3335.0 3195.5 2708 14553 1040.9 PushPop :ExecutorRunFusion::cuLaunchKernel +# 0.8 257659 100 2576.6 707.5 483 182114 18138.7 PushPop :FusionExecutorCache::setCacheId +# 0.4 116565 200 582.8 602.0 420 2139 184.6 PushPop :KernelExecutor::computeArgs2 +# 0.3 96582 100 965.8 835.5 772 9563 880.2 PushPop :executor_utils::bindInputs +# 0.2 66419 300 221.4 157.0 127 2363 221.8 PushPop :ExecutorDispatch::isCompiled +# 0.1 29389 100 293.9 132.5 111 10088 1006.3 PushPop :FusionExecutorCache::getKernelRuntimeFor +# 0.1 29125 200 145.6 108.0 97 798 107.9 PushPop :KernelExecutor::runFusion::intermediates diff --git 
a/tests/python/llama_inf_tests/graph_1.py b/tests/python/llama_inf_tests/graph_1.py index dfa92968a47..4581586db20 100644 --- a/tests/python/llama_inf_tests/graph_1.py +++ b/tests/python/llama_inf_tests/graph_1.py @@ -137,18 +137,72 @@ def nvfuser_fusion_id1(fd : FusionDefinition) -> None : fd.execute(inputs) -# for _ in range(3): -# fd.execute(inputs) +for _ in range(3): + fd.execute(inputs) -# torch.cuda.synchronize() -# start = time.time() -# # Mark the profiling region -# torch.cuda.cudart().cudaProfilerStart() +torch.cuda.synchronize() +start = time.time() +# Mark the profiling region +torch.cuda.cudart().cudaProfilerStart() -# for _ in range(100): -# fd.execute(inputs) +for _ in range(100): + fd.execute(inputs) -# torch.cuda.cudart().cudaProfilerStop() -# torch.cuda.synchronize() -# end = time.time() -# print(end-start) \ No newline at end of file +torch.cuda.cudart().cudaProfilerStop() +torch.cuda.synchronize() +end = time.time() + +print((end-start)*1000, " ms") + + +# Before: +# 19.8 ms +# After: +# 10.6 ms + +# rm report* +# nsys profile -c cudaProfilerApi python tests/python/llama_inf_tests/graph_1.py +# nsys stats report1.nsys-rep + +# Before: +# Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Style Range +# -------- --------------- --------- -------- -------- -------- -------- ----------- ------- ---------------------------------------------- +# 14.2 31791843 100 317918.4 268246.5 246507 762170 88138.8 PushPop :FusionExecutorCache::runFusionWithInputs +# 13.6 30602349 100 306023.5 261735.5 239889 737741 82786.1 PushPop :FusionKernelRuntime::runWithInputs +# 13.6 30480294 100 304802.9 260895.0 239116 735007 82461.4 PushPop :FusionKernelRuntime::runSegmentsWithInputs +# 13.0 29106369 1300 22389.5 18414.5 1815 266605 22601.1 PushPop :FusionKernelRuntime::runKernelWithInput +# 12.5 28090755 1300 21608.3 17832.0 1556 265963 22146.4 PushPop :ExecutorDispatch::run2 +# 8.1 18224542 500 36449.1 32308.0 16901 265512 24691.3 
PushPop :KernelExecutor::runFusion +# 7.2 16182053 4100 3946.8 3299.5 258 152312 4334.0 PushPop :ExpressionEvaluator::evaluate +# 5.3 11797199 500 23594.4 18368.5 10877 209862 17543.9 PushPop :KernelExecutor::runFusion::execute_kernel +# 4.1 9212380 500 18424.8 12260.5 7177 200666 15103.3 PushPop :KernelExecutor::recomputeArgs +# 3.9 8816273 800 11020.3 11422.5 1369 159960 9005.4 PushPop :ExprEvalExecutor::run +# 1.5 3394339 500 6788.7 3757.5 1977 196369 9934.9 PushPop :fusion_executor::allocations::allocateOutputs +# 1.2 2593352 900 2881.5 1502.5 1213 186775 6863.1 PushPop :fusion_executor::allocations::allocateTensor +# 1.0 2179586 500 4359.2 3768.0 2935 185683 8309.3 PushPop :ExecutorRunFusion::cuLaunchKernel +# 0.6 1350809 1300 1039.1 838.0 420 6663 603.7 PushPop :executor_utils::bindInputs +# 0.2 498945 1300 383.8 172.0 102 2870 400.6 PushPop :ExecutorDispatch::isCompiled +# 0.1 159524 500 319.0 120.0 97 6098 431.0 PushPop :KernelExecutor::runFusion::intermediates +# 0.1 145263 100 1452.6 1330.0 907 5442 594.3 PushPop :FusionExecutorCache::setCacheId +# 0.0 37566 100 375.7 182.0 130 1791 350.7 PushPop :FusionExecutorCache::getKernelRuntimeFor + +# After: +# Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Style Range +# -------- --------------- --------- -------- -------- -------- -------- ----------- ------- ---------------------------------------------- +# 16.1 19038957 100 190389.6 172709.0 145764 682753 77360.1 PushPop :FusionExecutorCache::runFusionWithInputs +# 15.2 18050839 100 180508.4 164976.0 140671 653055 72317.9 PushPop :FusionKernelRuntime::runWithInputs +# 15.2 17951998 100 179520.0 164040.5 139910 650172 72127.0 PushPop :FusionKernelRuntime::runSegmentsWithInputs +# 13.8 16310510 1300 12546.5 11597.5 1791 229198 15422.3 PushPop :FusionKernelRuntime::runKernelWithInput +# 13.2 15617300 1300 12013.3 11062.5 1542 228905 14374.1 PushPop :ExecutorDispatch::run2 +# 7.4 8714870 1900 4586.8 2470.0 255 214450 9636.4 
PushPop :ExpressionEvaluator::evaluate +# 7.2 8564673 800 10705.8 10866.5 1348 200141 10054.6 PushPop :ExprEvalExecutor::run +# 5.5 6549775 500 13099.5 10053.5 6840 228580 19230.8 PushPop :KernelExecutor::runFusion +# 1.9 2290425 500 4580.9 3967.5 3054 200810 9071.5 PushPop :KernelExecutor::runFusion::execute_kernel +# 1.8 2092522 500 4185.0 3594.0 2841 200306 9052.8 PushPop :ExecutorRunFusion::cuLaunchKernel +# 1.3 1562864 500 3125.7 1808.5 1417 205151 9320.5 PushPop :fusion_executor::allocations::allocateOutputs +# 0.6 717665 900 797.4 669.0 433 11922 529.8 PushPop :executor_utils::bindInputs +# 0.4 423085 500 846.2 644.0 258 12710 764.5 PushPop :KernelExecutor::computeArgs2 +# 0.2 264123 1300 203.2 153.0 103 2587 140.1 PushPop :ExecutorDispatch::isCompiled +# 0.1 131127 100 1311.3 1206.5 829 6301 574.0 PushPop :FusionExecutorCache::setCacheId +# 0.1 101062 500 202.1 122.0 98 2736 238.3 PushPop :KernelExecutor::runFusion::intermediates +# 0.0 20536 100 205.4 158.0 115 1009 135.9 PushPop :FusionExecutorCache::getKernelRuntimeFor diff --git a/tests/python/llama_inf_tests/graph_2.py b/tests/python/llama_inf_tests/graph_2.py index 753115abd61..0ea9137d433 100644 --- a/tests/python/llama_inf_tests/graph_2.py +++ b/tests/python/llama_inf_tests/graph_2.py @@ -88,18 +88,71 @@ def nvfuser_fusion_id2(fd : FusionDefinition) -> None : fd.execute(inputs) -# for _ in range(3): -# fd.execute(inputs) +for _ in range(3): + fd.execute(inputs) -# torch.cuda.synchronize() -# start = time.time() -# # Mark the profiling region -# torch.cuda.cudart().cudaProfilerStart() +torch.cuda.synchronize() +start = time.time() +# Mark the profiling region +torch.cuda.cudart().cudaProfilerStart() -# for _ in range(100): -# fd.execute(inputs) +for _ in range(100): + fd.execute(inputs) -# torch.cuda.cudart().cudaProfilerStop() -# torch.cuda.synchronize() -# end = time.time() -# print(end-start) \ No newline at end of file +torch.cuda.cudart().cudaProfilerStop() +torch.cuda.synchronize() +end = 
time.time() +print((end-start)*1000, " ms") + +# Before: +# 18.9 ms +# After: +# 18.8 ms + + +# rm report* +# nsys profile -c cudaProfilerApi python tests/python/llama_inf_tests/graph_2.py +# nsys stats report1.nsys-rep + +# Before: +# Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Style Range +# -------- --------------- --------- -------- -------- -------- -------- ----------- ------- ---------------------------------------------- +# 14.3 21273988 100 212739.9 191180.5 179102 647287 68515.0 PushPop :FusionExecutorCache::runFusionWithInputs +# 13.9 20711603 100 207116.0 185424.0 174989 625555 67241.1 PushPop :FusionKernelRuntime::runWithInputs +# 13.9 20634952 100 206349.5 184704.0 174362 623140 67108.8 PushPop :FusionKernelRuntime::runSegmentsWithInputs +# 13.1 19477550 900 21641.7 19736.5 5253 229134 19235.9 PushPop :FusionKernelRuntime::runKernelWithInput +# 12.8 18979906 900 21088.8 19402.5 5008 228699 18181.9 PushPop :ExecutorDispatch::run2 +# 9.1 13569155 2100 6461.5 3299.5 1250 188373 8072.4 PushPop :ExpressionEvaluator::evaluate +# 6.8 10071317 600 16785.5 16953.0 4816 226456 12582.7 PushPop :ExprEvalExecutor::run +# 5.8 8593835 300 28646.1 23748.5 18304 209021 24037.1 PushPop :KernelExecutor::runFusion +# 4.1 6042339 300 20141.1 17470.5 12833 200139 18266.7 PushPop :KernelExecutor::runFusion::execute_kernel +# 3.2 4802005 300 16006.7 13105.5 9464 195217 18063.1 PushPop :KernelExecutor::recomputeArgs +# 0.7 1083270 300 3610.9 3488.0 2803 9530 804.5 PushPop :ExecutorRunFusion::cuLaunchKernel +# 0.7 1066491 300 3555.0 2310.5 1934 173206 9901.1 PushPop :fusion_executor::allocations::allocateOutputs +# 0.7 1059251 900 1176.9 864.0 534 174544 5806.8 PushPop :executor_utils::bindInputs +# 0.5 753282 400 1883.2 1430.0 1237 169947 8427.1 PushPop :fusion_executor::allocations::allocateTensor +# 0.1 168892 900 187.7 147.0 103 1888 115.3 PushPop :ExecutorDispatch::isCompiled +# 0.1 154076 100 1540.8 1415.5 1012 6127 537.9 PushPop 
:FusionExecutorCache::setCacheId +# 0.0 51330 300 171.1 114.5 97 1893 169.0 PushPop :KernelExecutor::runFusion::intermediates +# 0.0 19650 100 196.5 156.0 117 845 109.3 PushPop :FusionExecutorCache::getKernelRuntimeFor + +# After: +# Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Style Range +# -------- --------------- --------- -------- -------- -------- -------- ----------- ------- ---------------------------------------------- +# 16.0 16373962 100 163739.6 141242.5 134284 628006 71340.5 PushPop :FusionExecutorCache::runFusionWithInputs +# 15.1 15382598 100 153826.0 136014.5 129986 600299 64199.4 PushPop :FusionKernelRuntime::runWithInputs +# 15.0 15308089 100 153080.9 135307.0 129396 597501 64038.0 PushPop :FusionKernelRuntime::runSegmentsWithInputs +# 13.8 14094916 900 15661.0 16073.0 5213 251448 17412.1 PushPop :FusionKernelRuntime::runKernelWithInput +# 13.3 13579134 900 15087.9 15684.5 4944 251078 16245.8 PushPop :ExecutorDispatch::run2 +# 9.7 9923498 600 16539.2 16927.0 4741 250699 14459.9 PushPop :ExprEvalExecutor::run +# 9.4 9632237 900 10702.5 13644.5 1314 248454 13132.5 PushPop :ExpressionEvaluator::evaluate +# 3.3 3330448 300 11101.5 8723.0 6965 201143 18774.3 PushPop :KernelExecutor::runFusion +# 1.1 1129730 300 3765.8 3564.0 2899 11811 941.0 PushPop :KernelExecutor::runFusion::execute_kernel +# 1.0 1013555 300 3378.5 3257.5 2673 10550 813.0 PushPop :ExecutorRunFusion::cuLaunchKernel +# 0.9 917298 300 3057.7 1641.0 1451 193374 11094.4 PushPop :fusion_executor::allocations::allocateOutputs +# 0.5 545570 600 909.3 832.0 531 11018 553.9 PushPop :executor_utils::bindInputs +# 0.4 374638 900 416.3 150.0 101 199723 6652.2 PushPop :ExecutorDispatch::isCompiled +# 0.2 204834 100 2048.3 148.5 110 185215 18502.2 PushPop :FusionExecutorCache::getKernelRuntimeFor +# 0.2 166768 300 555.9 478.5 341 5239 394.4 PushPop :KernelExecutor::computeArgs2 +# 0.2 165881 100 1658.8 1535.0 1125 7627 674.0 PushPop 
:FusionExecutorCache::setCacheId +# 0.1 57354 300 191.2 111.0 97 2077 236.0 PushPop :KernelExecutor::runFusion::intermediates From fec59b54842e9cd1cf0c215b4fcdc5c701ec4b7c Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 2 Mar 2025 13:49:44 -0800 Subject: [PATCH 15/16] Revert changes from expr eval executor prototype. --- csrc/device_lower/pass/replace_size.cpp | 3 + csrc/device_lower/pass/replace_size.h | 19 - csrc/expr_evaluator.cpp | 19 +- csrc/expr_evaluator.h | 7 - csrc/ir/base_nodes.cpp | 44 ++ csrc/ir/nodes.cpp | 837 ++++++++++++++++++++++++ csrc/runtime/executor_dispatch.cpp | 1 - csrc/runtime/executor_utils.cpp | 8 +- tests/cpp/test_alias.cpp | 1 - tests/cpp/test_evaluator.cpp | 21 +- 10 files changed, 897 insertions(+), 63 deletions(-) diff --git a/csrc/device_lower/pass/replace_size.cpp b/csrc/device_lower/pass/replace_size.cpp index d4c2ea474aa..2cfe1405050 100644 --- a/csrc/device_lower/pass/replace_size.cpp +++ b/csrc/device_lower/pass/replace_size.cpp @@ -18,6 +18,7 @@ namespace nvfuser { +namespace { // Going to generate a map of tensor view root domain extents to reduce the // number used during lowering. For example if we have: // @@ -136,6 +137,8 @@ std::unordered_map getSimplificationMap(Fusion* fusion) { return simplification_map; } +} // namespace + void replaceSymbolicSizes(Fusion* fusion) { FUSER_PERF_SCOPE("GpuLower::Lower::replaceSymbolicSizes"); std::unordered_map tensor_dim_map; diff --git a/csrc/device_lower/pass/replace_size.h b/csrc/device_lower/pass/replace_size.h index aab690f1df7..ca874ab836d 100644 --- a/csrc/device_lower/pass/replace_size.h +++ b/csrc/device_lower/pass/replace_size.h @@ -21,23 +21,4 @@ namespace nvfuser { // tensors to reference the runtime structure containing sizes. void replaceSymbolicSizes(Fusion*); -// Going to generate a map of tensor view root domain extents to reduce the -// number used during lowering. 
For example if we have: -// -// T2[i0, i1] = T1[i0, i1] + T2[i2, i3] -// -// We know it would be safe to use: -// -// T2[i0, i1] = T1[i0, i1] + T2[i0, i1] -// -// And that way we don't generate T2.size[0] and T2.size[1], instead we will -// reuse T1.size[0] and T1.size[1] -// This is important when doing CSE as T2 and T1 would otherwise look like -// they're using different values, even though we know they're the same -// -// There's some duplicate logic here that's in computeAt map, but it's not so -// concice there to pull out. May want to consider making this mapping its own -// class especially as it may be useful during scheduling. -std::unordered_map getSimplificationMap(Fusion* fusion); - } // namespace nvfuser diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp index cf5646df21e..a2ebccfb7b3 100644 --- a/csrc/expr_evaluator.cpp +++ b/csrc/expr_evaluator.cpp @@ -134,7 +134,6 @@ void ExpressionEvaluator::bindTensorDomain( const TensorView* tv, const at::Tensor& t, const bool evaluate_validate) { - FUSER_PERF_SCOPE("ExpressionEvaluator::bindTensorDomain"); auto logical_domain = TensorDomain::noReductions(tv->getLogicalDomain()); NVF_ERROR( t.dim() == (int64_t)logical_domain.size(), @@ -175,17 +174,10 @@ void ExpressionEvaluator::bindTensorDomain( } } -void ExpressionEvaluator::unsafeBind( - const Val* value, - PolymorphicValue concrete_value) { - known_values_[value] = concrete_value; -} - void ExpressionEvaluator::bind_( const Val* value, PolymorphicValue concrete_value, bool evaluate_validate) { - FUSER_PERF_SCOPE("ExpressionEvaluator::bind_"); using namespace PolymorphicValue_functions; NVF_CHECK(concrete_value.hasValue(), "Cannot bind to undefined value"); if (value->isConst()) { @@ -265,10 +257,6 @@ PolymorphicValue ExpressionEvaluator::evaluate(const Val* value) const { const PolymorphicValue& ExpressionEvaluator::evaluate( const Val* value, std::unordered_map& known_values) const { - // FUSER_PERF_SCOPE("ExpressionEvaluator::evaluate"); - // 
It's tempting to time this function, the issue is it's a recursive function - // so timings produced by it can be accumulatively longer than the actual time - // spent if (precomputed_values_ && precomputed_values_->hasValidValues()) { if (precomputed_values_->getMaybeValueFor(value).hasValue()) { return precomputed_values_->getMaybeValueFor(value); @@ -279,6 +267,7 @@ const PolymorphicValue& ExpressionEvaluator::evaluate( getValue(value, known_values); if (!maybe_concrete_value.get().hasValue()) { if (auto def = value->definition()) { + FUSER_PERF_SCOPE("ExpressionEvaluator::evaluate"); auto outputs = def->evaluate(*this, known_values); for (auto i : c10::irange(def->outputs().size())) { known_values[def->output(i)] = std::move(outputs[i]); @@ -286,12 +275,6 @@ const PolymorphicValue& ExpressionEvaluator::evaluate( maybe_concrete_value = getValue(value, known_values); } } - // TODO: Evaluate if an error like below could work - // NVF_ERROR( - // maybe_concrete_value.get().hasValue(), - // "Error evaluating a value in expression evaluator. Likely ", - // value->toString(), - // " needs to be bound to a value."); return maybe_concrete_value; } diff --git a/csrc/expr_evaluator.h b/csrc/expr_evaluator.h index 0a79defaaf8..b6c8e1857ea 100644 --- a/csrc/expr_evaluator.h +++ b/csrc/expr_evaluator.h @@ -22,7 +22,6 @@ namespace nvfuser { class PrecomputedValues; -class ExprEvalExecutor; //! Calculate Fusion IR expressions class ExpressionEvaluator { @@ -92,12 +91,6 @@ class ExpressionEvaluator { ExpressionEvaluator clone(IrCloner& ir_cloner) const; - protected: - friend ExprEvalExecutor; - // Direct access to adding values to known_values_ without going through bind_ - // which does validation and will also bind all tensor domain information. 
- void unsafeBind(const Val* value, PolymorphicValue concrete_value); - private: void bind_( const Val* value, diff --git a/csrc/ir/base_nodes.cpp b/csrc/ir/base_nodes.cpp index 88b7e93a18c..6e7e53e0f4d 100644 --- a/csrc/ir/base_nodes.cpp +++ b/csrc/ir/base_nodes.cpp @@ -6,6 +6,7 @@ */ // clang-format on #include +#include #include #include #include @@ -199,6 +200,20 @@ bool Val::isConstInt() const { return ir_utils::dependenciesSatisfied(this) && isIntegralScalar(); } +PolymorphicValue Val::evaluate() { + if (this->value().hasValue()) { + return this->value(); + } + + ExpressionEvaluator ee; + auto evaluated_val = ee.evaluate(this); + NVF_ERROR( + evaluated_val.hasValue(), + "Detected a const value but failed to infer its value: ", + toInlineString()); + return evaluated_val; +} + bool Val::isZero() const { return value().hasValue() && (bool)(value() == 0.0); } @@ -361,4 +376,33 @@ Expr* Expr::withWritePredicate(kir::Predicate* predicate) { return result; } +std::vector Expr::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + NVF_THROW( + "`evaluate` method for expression ", + getOpString(), + " is not defined. 
", + "Please override the evaluate method"); +} + +std::vector Expr::evaluate( + const ExpressionEvaluator& ee, + std::unordered_map& known_values) const { + std::vector expr_inputs; + expr_inputs.reserve(inputs().size()); + for (auto inp : inputs()) { + const auto& eval_i = ee.evaluate(inp, known_values); + if (!eval_i.hasValue()) { + return {std::monostate{}}; + } + expr_inputs.emplace_back(eval_i); + } + return this->evaluate(ee, expr_inputs); +} + +void Expr::addDataAttribute(PolymorphicValue attr) { + addAttribute(IrBuilder::createInContainer(container(), std::move(attr))); +} + } // namespace nvfuser diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 3377e703d22..05d89df7e0a 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -69,6 +70,20 @@ std::string FullOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } +std::vector FullOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + std::vector shape; + for (auto i : c10::irange(inputs.size() - 1)) { + shape.push_back(inputs.at(i).as()); + } + DataType dtype = getFillValue()->getDataType().value(); + const auto options = + at::TensorOptions().device(at::kCUDA).dtype(data_type_to_aten(dtype)); + using namespace PolymorphicValue_functions; + return {at::full(shape, toScalar(inputs.back()), options)}; +} + NVFUSER_DEFINE_CLONE_AND_CREATE(FullOp) SelectOp::SelectOp( @@ -104,6 +119,15 @@ IterDomain* SelectOp::getIndexedID() const { .at(dim()); } +std::vector SelectOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto& in = inputs.at(0).as(); + int64_t dimension = dim(); + int64_t index = (int64_t)inputs.at(1); + return {in.select(dimension, index)}; +} + NVFUSER_DEFINE_CLONE_AND_CREATE(SelectOp) IndexSelectOp::IndexSelectOp( @@ -143,6 +167,15 @@ IterDomain* IndexSelectOp::getConsumerOfIndexedID() const 
{ return ir_utils::getTvOutput(this)->getLogicalDomain().at(dim()); } +std::vector IndexSelectOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto& in = inputs.at(0).as(); + int64_t dimension = dim(); + const auto& indices = inputs.at(1).as().squeeze(); + return {at::index_select(in, dimension, indices)}; +} + NVFUSER_DEFINE_CLONE_AND_CREATE(IndexSelectOp) TorchGatherOp::TorchGatherOp( @@ -187,6 +220,19 @@ IterDomain* TorchGatherOp::getConsumerOfIndexedID() const { return ir_utils::getTvOutput(this)->getLogicalDomain().at(dim()); } +std::vector TorchGatherOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto& input = inputs.at(0).as(); + const auto& index = inputs.at(1).as(); + auto dimension = dim(); + if (exactSizes()) { + return {at::take_along_dim(input, index, dimension)}; + } else { + return {at::gather(input, dimension, index)}; + } +} + NVFUSER_DEFINE_CLONE_AND_CREATE(TorchGatherOp) ScatterOp::ScatterOp( @@ -225,6 +271,16 @@ IterDomain* ScatterOp::getIndexedID() const { return ir_utils::getTvOutput(this)->getLogicalDomain().at(dim()); } +std::vector ScatterOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto& input = inputs.at(0).as(); + const auto& index = inputs.at(1).as(); + const auto& src = inputs.at(2).as(); + auto dimension = dim(); + return {at::scatter(input, dimension, index, src)}; +} + NVFUSER_DEFINE_CLONE_AND_CREATE(ScatterOp) IotaOp::IotaOp( @@ -258,6 +314,31 @@ std::string IotaOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } +std::vector IotaOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto options = + at::TensorOptions().device(at::kCUDA).dtype(data_type_to_aten(dtype())); + int64_t length = (int64_t)inputs.at(0); + + if (isIntegralType(dtype())) { + int64_t start = (int64_t)inputs.at(1); + int64_t step = 
(int64_t)inputs.at(2); + int64_t end = start + step * length; + return {at::arange(start, end, step, options)}; + } else if (isFloatingPointType(dtype())) { + double start = (double)inputs.at(1); + double step = (double)inputs.at(2); + // Due to rounding error, it can be hard to guarantee the size of + // the output of arange to be exactly length, so we generate a + // larger tensor and truncate it to length. + double end = start + step * ((double)length + 1); + return {at::arange(start, end, step, options).narrow(0, 0, length)}; + } else { + NVF_THROW("Unsupported dtype in IotaOp evaluator: ", dtype()); + } +} + NVFUSER_DEFINE_CLONE_AND_CREATE(IotaOp) EyeOp::EyeOp(IrBuilderPasskey passkey, Val* out, DataType dtype) @@ -285,6 +366,19 @@ std::string EyeOp::toString(int indent_size) const { std::string EyeOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } +std::vector EyeOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto options = + at::TensorOptions().device(at::kCUDA).dtype(data_type_to_aten(dtype())); + int64_t nrows = (int64_t)inputs.at(0); + if (inputs.size() > 1) { + int64_t ncols = (int64_t)inputs.at(1); + return {at::eye(nrows, ncols, options)}; + } else { + return {at::eye(nrows, options)}; + } +} NVFUSER_DEFINE_CLONE_AND_CREATE(EyeOp) @@ -295,6 +389,133 @@ UnaryOp::UnaryOp(IrBuilderPasskey passkey, UnaryOpType type, Val* out, Val* in) addDataAttribute(type); } +std::vector UnaryOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + using namespace PolymorphicValue_functions; + + const auto& in = inputs.at(0); + if (!in.hasValue()) { + return {std::monostate{}}; + } + + switch (getUnaryOpType()) { + case UnaryOpType::Neg: + return {-in}; + case UnaryOpType::Cast: + if (in.is()) { + return {PolymorphicValue( + in.as().to(data_type_to_aten(out()->dtype())))}; + } else if (isIntegralType(*out()->getDataType())) { + return 
{PolymorphicValue((int64_t)in)}; + } else if (isFloatingPointType(*out()->getDataType())) { + return {PolymorphicValue((double)in)}; + } else if (out()->getDataType() == DataType::Bool) { + return {PolymorphicValue((bool)in)}; + } else if (isComplexType(*out()->getDataType())) { + return {PolymorphicValue((std::complex)in)}; + } else { + NVF_THROW("dtype not supported in evaluator: ", *out()->getDataType()); + } + case UnaryOpType::Reciprocal: + return {1.0 / in}; + break; + case UnaryOpType::Abs: + return {abs(in)}; + break; + case UnaryOpType::LogicalNot: + return {!in}; + break; + case UnaryOpType::BitwiseNot: + return {~in}; + break; + case UnaryOpType::Erf: + return {erf(in)}; + break; + case UnaryOpType::ToUnsignedSmemAddr: + return {(int64_t)(unsigned)in}; + break; + case UnaryOpType::AdjustPartialLdMatrixAddrInTuring8: + case UnaryOpType::AdjustPartialLdMatrixAddrInTuring16: + return {in}; + break; + case UnaryOpType::Dereference: + if (*out()->getDataType() == DataType::Float) { + return {PolymorphicValue((double)*(float*)in)}; + } else { + NVF_THROW("dtype not supported in evaluator: ", *out()->getDataType()); + } + break; + case UnaryOpType::Sigmoid: + return {in.as().sigmoid()}; + break; + case UnaryOpType::Tanh: + return {in.as().tanh()}; + break; + case UnaryOpType::Relu: + return {at::relu(in.as())}; + break; + case UnaryOpType::Gelu: + return {at::gelu(in.as())}; + break; + case UnaryOpType::Exp: + return {at::exp(in.as())}; + break; + case UnaryOpType::Sin: + return {in.as().sin()}; + break; + case UnaryOpType::Signbit: + return {signbit(in)}; + break; + case UnaryOpType::Cos: + return {in.as().cos()}; + break; + case UnaryOpType::BitCast: + NVF_CHECK( + dataTypeSize(input(0)->dtype()) == dataTypeSize(out()->dtype()), + "BitCast only works for types of the same size"); + if (isComplexType(input(0)->dtype()) && + std::holds_alternative(out()->dtype().type)) { + // view_as_real case. 
+ auto vec_type = std::get(out()->dtype().type); + auto inp_scalar_type = getTypeFromComplexType(input(0)->dtype()); + NVF_CHECK( + *vec_type.type == inp_scalar_type, + "Output type must be the same as the scalar type of the complex input."); + NVF_CHECK( + vec_type.size == 2, + "Expected output to be array of size 2, found array of size ", + vec_type.size); + return {in.as()}; + } else { + return {in.as().view(data_type_to_aten(out()->dtype()))}; + } + break; + case UnaryOpType::Rsqrt: + return {in.as().rsqrt()}; + break; + case UnaryOpType::Real: + return {at::real(in.as())}; + break; + case UnaryOpType::Imag: + return {at::imag(in.as())}; + break; + case UnaryOpType::Tan: + return {in.as().tan()}; + break; + case UnaryOpType::IsFinite: + return {at::isfinite(in.as())}; + break; + default: + NVF_CHECK( + false, + "Unexpected operator type ", + getUnaryOpType(), + " in ", + toString()); + } +} + void UnaryOp::printHelper(std::stringstream& ss, std::string input) const { auto op_type = getUnaryOpType(); @@ -537,6 +758,38 @@ TernaryOp::TernaryOp( addDataAttribute(type); } +std::vector TernaryOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + using namespace PolymorphicValue_functions; + const auto& a = inputs.at(0); + const auto& b = inputs.at(1); + const auto& c = inputs.at(2); + switch (getTernaryOpType()) { + case TernaryOpType::Clamp: + return {std::min(std::max(a, b), c)}; + break; + case TernaryOpType::Lerp: + // This is the same lerp computed in helpers.cu + // https://math.stackexchange.com/a/1798323 + return {(c < 0.5) ? a + c * (b - a) : b - (b - a) * (1.0 - c)}; + break; + case TernaryOpType::Threshold: + return {(a <= b) ? c : a}; + break; + case TernaryOpType::Where: + return {a.as() ? 
b : c}; + break; + default: + NVF_CHECK( + false, + "Unexpected operator type: ", + getTernaryOpType(), + " in ", + toString()); + } +} + void TernaryOp::printHelper( std::stringstream& ss, int indent_size, @@ -637,6 +890,12 @@ std::string ArrayConstruct::toInlineString(int indent_size) const { return ss.str(); } +std::vector ArrayConstruct::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + return {PolymorphicValue(inputs)}; +} + NVFUSER_DEFINE_CLONE_AND_CREATE(ArrayConstruct) ReverseArray::ReverseArray(IrBuilderPasskey passkey, Val* output, Val* input) @@ -678,6 +937,16 @@ std::string ReverseArray::toInlineString(int indent_size) const { return ss.str(); } +std::vector ReverseArray::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + NVF_ERROR(inputs.size() == 1, "ReverseArray expects 1 input"); + PolymorphicValue array = inputs.at(0); + auto& vec = array.as(); + std::reverse(vec.begin(), vec.end()); + return {std::move(array)}; +} + NVFUSER_DEFINE_CLONE_AND_CREATE(ReverseArray) GetItem::GetItem(IrBuilderPasskey passkey, Val* output, Val* array, Val* index) @@ -704,6 +973,13 @@ std::string GetItem::toInlineString(int indent_size) const { return ss.str(); } +std::vector GetItem::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + NVF_ERROR(inputs.size() == 2, "GetItem expects 2 inputs"); + return {PolymorphicValue(inputs.at(0)[inputs.at(1)])}; +} + NVFUSER_DEFINE_CLONE_AND_CREATE(GetItem) StructConstruct::StructConstruct( @@ -759,6 +1035,22 @@ std::string StructConstruct::toInlineString(int indent_size) const { return ss.str(); } +std::vector StructConstruct::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + NVF_ERROR( + this->inputs().size() == inputs.size(), + "StructConstruct expects ", + this->inputs().size(), + " inputs"); + PolymorphicValue struct_ = + std::get(output(0)->dtype().type).create(); + for (int64_t i : 
c10::irange((int64_t)inputs.size())) { + struct_->*attribute(i) = inputs.at(i); + } + return {std::move(struct_)}; +} + NVFUSER_DEFINE_CLONE_AND_CREATE(StructConstruct) GetAttr::GetAttr( @@ -789,6 +1081,13 @@ std::string GetAttr::toInlineString(int indent_size) const { return ss.str(); } +std::vector GetAttr::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + NVF_ERROR(inputs.size() == 1, "GetAttr expects 1 input"); + return {inputs.at(0)->*attr()}; +} + NVFUSER_DEFINE_CLONE_AND_CREATE(GetAttr) GetMetaData::GetMetaData(IrBuilderPasskey passkey, Val* output, Val* input) @@ -835,6 +1134,14 @@ std::string TensorConstruct::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } +std::vector TensorConstruct::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + NVF_ERROR(inputs.size() == 1, "TensorConstruct expects 1 input"); + using namespace PolymorphicValue_functions; + return {toTensor(inputs.at(0))}; +} + NVFUSER_DEFINE_CLONE_AND_CREATE(TensorConstruct) RNGOp::RNGOp( @@ -984,6 +1291,26 @@ std::string BroadcastOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } +std::vector BroadcastOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + NVF_ERROR( + inputs.size() == 1, + "BroadcastOp expects exactly 1 input, but received ", + inputs.size()); + std::vector out_shape; + const auto& in = inputs.at(0).as(); + int64_t idx = 0; + for (bool b : getBroadcastDimFlags()) { + if (b) { + out_shape.push_back(1); + } else { + out_shape.push_back(in.sizes()[idx++]); + } + } + return {in.view(out_shape)}; +} + NVFUSER_DEFINE_CLONE_AND_CREATE(BroadcastOp) SqueezeOp::SqueezeOp( @@ -1070,6 +1397,35 @@ std::string SqueezeOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } +std::vector SqueezeOp::evaluate( + const ExpressionEvaluator& ee, + const 
std::vector& inputs) const { + NVF_ERROR( + inputs.size() == 1, + "SqueezeOp expects exactly 1 input, but received ", + inputs.size()); + std::vector out_shape; + const auto& in = inputs.at(0).as(); + const auto& is_squeeze_dims = getSqueezeDimFlags(); + NVF_ERROR( + (int64_t)is_squeeze_dims.size() == in.dim(), + "The dimensions of input tensor and does not match with is_squeeze_dims"); + at::Tensor out = in; + for (int64_t i : c10::irange((int64_t)is_squeeze_dims.size())) { + if (is_squeeze_dims[i]) { + if (in.stride(i) == 0) { + // If the input dimension is expanded in this dimension, undo the expand + // by slicing. This ensures that any broadcast dimensions will be + // unexpanded when we do the final call to view() + out = out.slice(i, 0, 1); + } + } else { + out_shape.push_back(in.sizes()[i]); + } + } + return {out.view(out_shape)}; +} + void SqueezeOp::checkConcretization(Val* old_val, Val* new_val) const { Expr::checkConcretization(old_val, new_val); // does nullptr, vtype checks NVF_CHECK( @@ -1164,6 +1520,43 @@ std::string ReductionOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } +std::vector ReductionOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto& input = inputs.at(0).as(); + const auto output = out()->as(); + + NVF_ERROR( + !output->hasRoot(), + "Evaluation for rFactored reductions is not supported."); + + std::vector reduction_axes; + for (const auto i : c10::irange(int64_t(output->getLogicalDomain().size()))) { + auto ax = output->getLogicalDomain().at(i); + if (ax->isReduction()) { + reduction_axes.push_back(i); + } + } + switch (getReductionOpType()) { + case BinaryOpType::Add: + return {at::sum(input, reduction_axes)}; + break; + case BinaryOpType::Max: + return {at::amax(input, reduction_axes)}; + break; + case BinaryOpType::Min: + return {at::amin(input, reduction_axes)}; + break; + default: + NVF_CHECK( + false, + "Unexpected operator type: ", 
+ getReductionOpType(), + " in ", + toString()); + } +} + NVFUSER_DEFINE_CLONE_AND_CREATE(ReductionOp) GroupedReductionOp::GroupedReductionOp( @@ -1218,6 +1611,46 @@ int GroupedReductionOp::getExprIndexOfOutput(Val* output_val) const { NVF_THROW("Not an output, ", output_val->toString(), ", of ", toString()); } +std::vector GroupedReductionOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto num_reductions = numHorizontallyGroupedExprs(); + std::vector grouped_reduction_out; + grouped_reduction_out.reserve(num_reductions); + for (const auto i : c10::irange(num_reductions)) { + const auto& in_tensor = inputs.at(i).as(); + const auto out_tv = output(i)->as(); + NVF_ERROR( + !out_tv->hasRoot(), + "Evaluation for rFactored reductions is not supported."); + + std::vector reduction_axes; + for (const auto id : + c10::irange(int64_t(out_tv->getLogicalDomain().size()))) { + auto ax = out_tv->getLogicalDomain().at(id); + if (ax->isReduction()) { + reduction_axes.push_back(id); + } + } + switch (getReductionOpType(i)) { + case BinaryOpType::Add: + grouped_reduction_out.emplace_back(at::sum(in_tensor, reduction_axes)); + break; + case BinaryOpType::Max: + grouped_reduction_out.emplace_back(at::amax(in_tensor, reduction_axes)); + break; + default: + NVF_CHECK( + false, + "Unexpected operator type: ", + getReductionOpType(i), + " in ", + toString()); + } + } + return grouped_reduction_out; +} + NVFUSER_DEFINE_CLONE_AND_CREATE(GroupedReductionOp) std::optional WelfordTriplet::getNameOf( @@ -1399,6 +1832,32 @@ std::string WelfordOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } +std::vector WelfordOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + NVF_ERROR( + !hasInit(), + "Evaluation for WelfordOp is not implemented for non-empty initial values."); + const auto& in_tensor = inputs.at(0).as(); + const auto out_tv = out()->as(); + NVF_ERROR( + 
!out_tv->hasRoot(), + "Evaluation for WelfordOp is not supported when output is rFactored."); + + int64_t N = 1; + std::vector reduction_axes; + for (const auto i : c10::irange(int64_t(out_tv->getLogicalDomain().size()))) { + auto ax = out_tv->getLogicalDomain().at(i); + if (ax->isReduction()) { + reduction_axes.push_back(i); + N *= in_tensor.size(i); + } + } + const auto [in_var, in_avg] = + at::var_mean(in_tensor, reduction_axes, false, false); + return {in_avg, in_var * N, N}; +} + NVFUSER_DEFINE_CLONE_AND_CREATE(WelfordOp) GroupedWelfordOp::GroupedWelfordOp( @@ -1678,6 +2137,17 @@ std::string ExpandOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } +std::vector ExpandOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto& in = inputs.at(0).as(); + std::vector expanded_size; + for (auto i : c10::irange(1, inputs.size())) { + expanded_size.push_back((int64_t)inputs.at(i)); + } + return {in.expand(expanded_size)}; +} + NVFUSER_DEFINE_CLONE_AND_CREATE(ExpandOp) RepeatOp::RepeatOp(IrBuilderPasskey passkey, TensorView* out, TensorView* in) @@ -1727,6 +2197,37 @@ std::string RepeatOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } +std::vector RepeatOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + NVF_ERROR( + inputs.size() == 1, + "RepeatOp expects exactly 1 input, but received ", + inputs.size()); + auto tensor = inputs.at(0).as(); + std::vector multipliers; + multipliers.reserve(out()->getLogicalDomain().size()); + const auto c2p = + PairwiseLogicalDomainMap(in(), out()).mapConsumerToProducer(); + for (const auto i : c10::irange(out()->getLogicalDomain().size())) { + auto out_id = out()->getLogicalDomain().at(i); + auto inp_id = c2p.at(out_id); + auto out_extent = ee.evaluate(out_id->extent()).as(); + auto inp_extent = ee.evaluate(inp_id->extent()).as(); + NVF_ERROR( + out_extent % 
inp_extent == 0, + "For dimension ", + i, + ", the output extent (", + out_extent, + " should be a multiple of the input extent (", + inp_extent, + ")."); + multipliers.push_back(out_extent / inp_extent); + } + return {tensor.repeat(multipliers)}; +} + NVFUSER_DEFINE_CLONE_AND_CREATE(RepeatOp) ViewAsScalar::ViewAsScalar( @@ -1752,6 +2253,13 @@ std::string ViewAsScalar::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } +std::vector ViewAsScalar::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const at::Tensor& in = inputs.at(0).as(); + return {at::view_as_real(in)}; +} + NVFUSER_DEFINE_CLONE_AND_CREATE(ViewAsScalar) ViewOp::ViewOp(IrBuilderPasskey passkey, Val* out, Val* in) : Expr(passkey) { @@ -1778,6 +2286,33 @@ std::string ViewOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } +std::vector ViewOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + NVF_ERROR(inputs.size() == 1); + const at::Tensor& in_tensor = inputs[0].as(); + + const std::vector& out_logical = out()->getLogicalDomain(); + std::vector out_shape; + out_shape.reserve(out_logical.size()); + for (IterDomain* id : out_logical) { + if (id->isDeviceDim()) { + out_shape.push_back(1); + } else { + out_shape.push_back( + ee.evaluate(id->getMaybeExpandedExtent()).as()); + } + } + + // TODO: check allocation domain and contiguity. + + // Use `at::Tensor::reshape` instead of `at::Tensor::view` because `ViewOp` + // doesn't always produce an alias. For example, when merging an expanded + // `IterType::Broadcast` and an `IterType::Iteration`, `ViewOp` has to realize + // the expand. 
+ return {in_tensor.reshape(out_shape)}; +} + NVFUSER_DEFINE_CLONE_AND_CREATE(ViewOp) LoadStoreOp::LoadStoreOp( @@ -1811,6 +2346,27 @@ LoadStoreOp::LoadStoreOp( addDataAttribute(cache_op); } +std::vector LoadStoreOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + if (TensorView* out_tv = dynamic_cast(out())) { + if (out_tv->hasRoot()) { + std::optional> permutation = + ir_utils::computePermutation( + out_tv->getRootDomain(), out_tv->getLogicalDomain()); + NVF_ERROR( + permutation.has_value(), + "The logical domain of a Set.Permute is supposed to be a permutation of the root domain: ", + out_tv->toString()); + NVF_ERROR(inputs.size() == 1); + at::Tensor in_tensor = inputs[0].as(); + at::Tensor out_tensor = in_tensor.permute(*permutation); + return {out_tensor}; + } + } + return inputs; +} + std::string LoadStoreOp::toString(int indent_size) const { std::stringstream ss; std::string optype = load_store_type2string(opType()); @@ -3753,6 +4309,36 @@ std::pair PadOp::getPadWidths(int64_t axis) const { (*(getPadWidthInputBegin() + offset_odd))->as()); } +std::vector PadOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto& in = inputs.at(0).as(); + + std::vector pad_widths; + auto pad_width_offset = getPadWidthInputOffset(); + auto num_dims = in.dim(); + + for (auto i = num_dims - 1; i > -1; i--) { + auto left_pad = (int64_t)inputs.at(pad_width_offset + 2 * i); + auto right_pad = (int64_t)inputs.at(pad_width_offset + 2 * i + 1); + pad_widths.push_back(left_pad); + pad_widths.push_back(right_pad); + } + + if (isComplexType(*out()->getDataType())) { + std::complex value = + static_cast>(inputs.at(1)); + auto real = at::real(in); + auto imag = at::imag(in); + auto padded_real = at::pad(real, pad_widths, "constant", value.real()); + auto padded_imag = at::pad(imag, pad_widths, "constant", value.imag()); + return {at::complex(padded_real, padded_imag)}; + } else { + double value = 
static_cast(inputs.at(1)); + return {at::pad(in, pad_widths, "constant", value)}; + } +} + SliceOp::SliceOp( IrBuilderPasskey passkey, TensorView* out, @@ -3821,6 +4407,22 @@ std::vector SliceOp::getRanges() const { return ranges; } +std::vector SliceOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto& in = inputs.at(0).as(); + std::vector ranges; + auto ranges_offset = getRangeInputOffset(); + auto num_dims = in.dim(); + for (const auto i : c10::irange(num_dims)) { + auto start = (int64_t)inputs.at(ranges_offset + 3 * i); + auto stop = (int64_t)inputs.at(ranges_offset + 3 * i + 1); + auto step = (int64_t)inputs.at(ranges_offset + 3 * i + 2); + ranges.emplace_back(at::indexing::Slice(start, stop, step)); + } + return {in.index(ranges)}; +} + CatOp::CatOp( IrBuilderPasskey passkey, Val* out, @@ -3917,6 +4519,24 @@ Val* CatOp::getPred(int input_idx) const { return pred; } +std::vector CatOp::evaluate( + const ExpressionEvaluator& ee, + std::unordered_map& known_values) const { + // CatOp is preceded by a PadOp internally. + // For ATen evaluation, directly compute the unpadded inputs. 
+ std::vector unpadded_inputs; + unpadded_inputs.reserve(inputs().size()); + int64_t concat_dim = concatenatedDim(); + for (Val* inp : inputs()) { + NVF_CHECK( + inp->definition() != nullptr && inp->definition()->isA(), + "Expected CatOp to be preceded by a PadOp."); + auto eval_i = ee.evaluate(inp->definition()->input(0), known_values); + unpadded_inputs.push_back(eval_i.as()); + } + return {at::cat(unpadded_inputs, concat_dim)}; +} + MatmulOp::MatmulOp(IrBuilderPasskey passkey, Val* out, Val* in_a, Val* in_b) : Expr(passkey) { addOutput(out); @@ -4127,6 +4747,106 @@ std::string SdpaFwdOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } +std::vector SdpaFwdOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + auto query = inputs.at(0).as(); + auto key = inputs.at(1).as(); + auto value = inputs.at(2).as(); + + const auto dropout_p = inputs.at(3).as(); + const auto is_causal = inputs.at(4).as(); + + // Temporary handling of DID parallelization see + // https://github.com/NVIDIA/Fuser/issues/2563 + bool handle_device_dim = false; + if (query.dim() == 5) { + handle_device_dim = true; + + NVF_CHECK(key.dim() == 5 && value.dim() == 5); + + auto query_domain = + TensorDomain::noReductions(this->query()->getLogicalDomain()); + auto key_domain = + TensorDomain::noReductions(this->key()->getLogicalDomain()); + auto value_domain = + TensorDomain::noReductions(this->value()->getLogicalDomain()); + NVF_CHECK( + query_domain.front()->isDeviceDim(), + "Only support DID parallelization on outermost axis"); + NVF_CHECK( + key_domain.front()->isDeviceDim(), + "Only support DID parallelization on outermost axis"); + NVF_CHECK( + value_domain.front()->isDeviceDim(), + "Only support DID parallelization on outermost axis"); + + query = query.squeeze(0); + key = key.squeeze(0); + value = value.squeeze(0); + } + + // Flash attention requires the last dimension to be padded to 8. 
+ // https://github.com/pytorch/pytorch/blob/c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad/aten/src/ATen/native/transformers/attention.cpp#L675-L677 + const auto last_dim_size = query.size(-1); + auto pad_last_dim = [last_dim_size]( + at::Tensor inp, int alignment_size) -> at::Tensor { + if (last_dim_size % alignment_size == 0) { + return inp; + } + auto pad_count = alignment_size - (last_dim_size % alignment_size); + auto padded_inp = at::pad(inp, {0, pad_count}); + return padded_inp; + }; + + query = pad_last_dim(query, 8); + key = pad_last_dim(key, 8); + value = pad_last_dim(value, 8); + + // Conmpute scale using original size of last dimension + double scale = inputs.size() > 5 ? inputs.back().as() + : 1.0 / std::sqrt(last_dim_size); + + // ATen reference: + // https://github.com/pytorch/pytorch/blob/c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad/aten/src/ATen/native/transformers/attention.cpp#L680-L681 + auto + [output, + log_sumexp, + cum_seq_q, + cum_seq_k, + query_seq_len, + key_seq_len, + philox_seed, + philox_offset, + debug_attn_mask] = + at::_scaled_dot_product_flash_attention( + query, + key, + value, + dropout_p, + is_causal, + /*return_debug_mask=*/false, + scale); + + // If the inputs were padded, slice the output to restore the original + // size + if (output.size(-1) != last_dim_size) { + output = output.slice(-1, 0, last_dim_size); + } + + // Add back the device dim axis for output. + if (handle_device_dim) { + output = output.unsqueeze(0); + log_sumexp = log_sumexp.unsqueeze(0); + } + + // We ignore cum_seq_q/k outputs since they are undefined tensors for + // non-nested tensors. We do not store query/key_seq_len since they can be + // computed in non-nested tensor directly. debug_attn_mask is ignored + // since `return_debug_mask=false`. 
+ return {output, log_sumexp, philox_seed, philox_offset}; +} + std::string Scope::toString(int indent_size) const { std::stringstream ss; for (auto expr : exprs()) { @@ -4607,6 +5327,94 @@ std::string SdpaBwdOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } +std::vector SdpaBwdOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + // Backward tensor inputs: grad_input, query, key, value, output, + // logsumexp, max_q/k Temporary handling of DID parallelization. See + // https://github.com/NVIDIA/Fuser/issues/2563 + bool first_dim_is_did = this->key()->as()->axis(0)->isDeviceDim(); + auto out_grad = inputs[0].as(); + if (first_dim_is_did) { + NVF_CHECK(out_grad.dim() == 5, "Expected 5D but found ", out_grad.sizes()); + } else { + NVF_CHECK(out_grad.dim() == 4, "Expected 4D but found ", out_grad.sizes()); + } + + std::vector bwd_inputs; + for (auto idx : c10::irange(6)) { + auto in_tensor = inputs.at(idx).as(); + // Removing the size 1 from sharded axis from tensors. + if (first_dim_is_did) { + in_tensor = in_tensor.squeeze(0); + } + bwd_inputs.push_back(in_tensor); + } + const auto dropout_p = inputs.at(6).as(); + const auto is_causal = inputs.at(7).as(); + const auto philox_seed = inputs.at(8).as(); + const auto philox_offset = inputs.at(9).as(); + + // Flash attention requires the last dimension to be padded to 8. 
+ // https://github.com/pytorch/pytorch/blob/c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad/aten/src/ATen/native/transformers/attention.cpp#L675-L677 + const auto last_dim_size = bwd_inputs[0].size(-1); + auto pad_last_dim = [last_dim_size]( + at::Tensor inp, int alignment_size) -> at::Tensor { + if (last_dim_size % alignment_size == 0) { + return inp; + } + auto pad_count = alignment_size - (last_dim_size % alignment_size); + auto padded_inp = at::pad(inp, {0, pad_count}); + return padded_inp; + }; + + // Conmpute scale using original size of last dimension + double scale = inputs.size() > 10 ? inputs.back().as() + : 1.0 / std::sqrt(last_dim_size); + + // ATen reference: + // https://github.com/pytorch/pytorch/blob/c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad/aten/src/ATen/native/transformers/attention.cpp#L680-L681 + // cum_seq_q/k are undefined tensors for non-nested input tensors. + auto [grad_query, grad_key, grad_value] = + at::_scaled_dot_product_flash_attention_backward( + /*grad_output=*/pad_last_dim(bwd_inputs[0], 8), + /*query=*/pad_last_dim(bwd_inputs[1], 8), + /*key=*/pad_last_dim(bwd_inputs[2], 8), + /*value=*/pad_last_dim(bwd_inputs[3], 8), + /*output=*/pad_last_dim(bwd_inputs[4], 8), + /*logsumexp=*/bwd_inputs[5], + /*cum_seq_q=*/at::Tensor(), + /*cum_seq_k=*/at::Tensor(), + // Note: ATen implementation expects max_q/max_k as scalars. + /*max_q=*/bwd_inputs[1].size(2), + /*max_k=*/bwd_inputs[2].size(2), + /*dropout_p=*/dropout_p, + /*is_causal=*/is_causal, + /*philox_seed=*/philox_seed, + /*philox_offset=*/philox_offset, + /*scale=*/scale); + + // If the inputs were padded, slice the gradsto restore the original size + auto slice_last_dim = [last_dim_size](at::Tensor output) -> at::Tensor { + if (output.size(-1) != last_dim_size) { + return output; + } + return output.slice(-1, 0, last_dim_size); + }; + + // Add device dimension back to outputs. 
+ if (first_dim_is_did) { + grad_query = grad_query.unsqueeze(0); + grad_key = grad_key.unsqueeze(0); + grad_value = grad_value.unsqueeze(0); + } + + return { + slice_last_dim(grad_query), + slice_last_dim(grad_key), + slice_last_dim(grad_value)}; +} + EmbeddingFwdOp::EmbeddingFwdOp( IrBuilderPasskey passkey, TensorView* output, @@ -4668,4 +5476,33 @@ std::string EmbeddingFwdOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } +std::vector EmbeddingFwdOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + auto input = inputs.at(0).as(); + auto weight = inputs.at(1).as(); + auto norm_type = inputs.at(2).as(); + auto scale_grad_by_freq = inputs.at(3).as(); + auto sparse = inputs.at(4).as(); + std::optional padding_idx = std::nullopt; + if (has_padding_idx()) { + padding_idx = inputs.at(5).as(); + } + std::optional max_norm = std::nullopt; + if (has_max_norm()) { + auto idx = 5 + has_padding_idx(); + max_norm = inputs.at(idx).as(); + } + + namespace F = torch::nn::functional; + return {F::embedding( + input, + weight, + F::EmbeddingFuncOptions() + .padding_idx(padding_idx) + .max_norm(max_norm) + .norm_type(norm_type) + .scale_grad_by_freq(scale_grad_by_freq) + .sparse(sparse))}; +} } // namespace nvfuser diff --git a/csrc/runtime/executor_dispatch.cpp b/csrc/runtime/executor_dispatch.cpp index 5a56b43e9a1..8af0174361a 100644 --- a/csrc/runtime/executor_dispatch.cpp +++ b/csrc/runtime/executor_dispatch.cpp @@ -10,7 +10,6 @@ #include #include -#include #include diff --git a/csrc/runtime/executor_utils.cpp b/csrc/runtime/executor_utils.cpp index 0e941a67afd..6e0c2d769a5 100644 --- a/csrc/runtime/executor_utils.cpp +++ b/csrc/runtime/executor_utils.cpp @@ -564,17 +564,17 @@ void validateVectorizedTensors( ExpressionEvaluator bindInputs( const KernelArgumentHolder& args, - Fusion* fusion) { + Fusion* kernel) { FUSER_PERF_SCOPE("executor_utils::bindInputs"); // args may contains more than just 
inputs, but inputs are always at the // beginning. NVF_ERROR( - fusion->inputs().size() <= args.size(), - "KernelArgumentHolder contains less argument than fusion's input."); + kernel->inputs().size() <= args.size(), + "KernelArgumentHolder contains less argument than kernel's input."); ExpressionEvaluator expr_eval; - const auto& inputs = fusion->inputs(); + const auto& inputs = kernel->inputs(); for (const auto i : c10::irange(inputs.size())) { // NOTE: we bind all inputs here, including at::Tensors. This means that // expr_eval will create a PolymorphicValue containing *args[i], which means diff --git a/tests/cpp/test_alias.cpp b/tests/cpp/test_alias.cpp index 14024c5aa40..e171f4cb4a3 100644 --- a/tests/cpp/test_alias.cpp +++ b/tests/cpp/test_alias.cpp @@ -18,7 +18,6 @@ #include #include #include -#include #include #include #include diff --git a/tests/cpp/test_evaluator.cpp b/tests/cpp/test_evaluator.cpp index d9df89c34be..fb088613a97 100644 --- a/tests/cpp/test_evaluator.cpp +++ b/tests/cpp/test_evaluator.cpp @@ -15,7 +15,6 @@ #include #include #include -#include namespace nvfuser { @@ -595,19 +594,15 @@ TEST_F(ExprEvalTest, ReshapePermuteReshape) { out = reshape(out, {IrBuilder::create(6), size(out, 2)}); fusion.addOutput(out); - fusion.aliasOutputToInput(out, in, AllocationType::Evaluate); at::Tensor in_tensor = at::rand({72}).cuda().as_strided({9, 6}, {8, 1}); - ExprEvalExecutor eee; - eee.compile(&fusion); - auto args = KernelArgumentHolder::createKernelArgumentHolder({in_tensor}); - auto outs = eee.run(args); - for (auto i : c10::irange(99)) { - (void)i; - eee.run(args); - } - EXPECT_EQ(in_tensor.data_ptr(), outs[0].data_ptr()); - EXPECT_THAT(outs[0].sizes(), ElementsAre(6, 9)); - EXPECT_THAT(outs[0].strides(), ElementsAre(1, 8)); + + ExpressionEvaluator evaluator; + evaluator.bind(in, in_tensor); + at::Tensor out_tensor = evaluator.evaluate(out).as(); + + EXPECT_EQ(in_tensor.data_ptr(), out_tensor.data_ptr()); + EXPECT_THAT(out_tensor.sizes(), 
ElementsAre(6, 9)); + EXPECT_THAT(out_tensor.strides(), ElementsAre(1, 8)); } TEST_F(ExprEvalTest, Reshape_ForwardBroadcast) { From 3622f2022eb324e0ce2d2be0038acce366517986 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 2 Mar 2025 13:50:33 -0800 Subject: [PATCH 16/16] Revert changes from expr eval executor prototype. --- CMakeLists.txt | 2 - csrc/ir/evaluate.cpp | 1135 ------------------------------- csrc/runtime/expr_eval_exec.cpp | 353 ---------- csrc/runtime/expr_eval_exec.h | 119 ---- 4 files changed, 1609 deletions(-) delete mode 100644 csrc/ir/evaluate.cpp delete mode 100644 csrc/runtime/expr_eval_exec.cpp delete mode 100644 csrc/runtime/expr_eval_exec.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 63f09fc738a..8157e6fdcb9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -161,7 +161,6 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/ir/builder.cpp ${NVFUSER_SRCS_DIR}/ir/cloner.cpp ${NVFUSER_SRCS_DIR}/ir/container.cpp - ${NVFUSER_SRCS_DIR}/ir/evaluate.cpp ${NVFUSER_SRCS_DIR}/ir/graphviz.cpp ${NVFUSER_SRCS_DIR}/ir/iostream.cpp ${NVFUSER_SRCS_DIR}/ir/nodes.cpp @@ -217,7 +216,6 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/runtime/executor_kernel_arg.cpp ${NVFUSER_SRCS_DIR}/runtime/executor_params.cpp ${NVFUSER_SRCS_DIR}/runtime/executor_utils.cpp - ${NVFUSER_SRCS_DIR}/runtime/expr_eval_exec.cpp ${NVFUSER_SRCS_DIR}/runtime/fusion_cache_utils.cpp ${NVFUSER_SRCS_DIR}/runtime/fusion_executor_cache.cpp ${NVFUSER_SRCS_DIR}/runtime/fusion_kernel_runtime.cpp diff --git a/csrc/ir/evaluate.cpp b/csrc/ir/evaluate.cpp deleted file mode 100644 index ed48c58c8ca..00000000000 --- a/csrc/ir/evaluate.cpp +++ /dev/null @@ -1,1135 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. 
- * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on - -#include -#include -#include -#include -#include - -namespace nvfuser { - -PolymorphicValue Val::evaluate() { - if (this->value().hasValue()) { - return this->value(); - FUSER_PERF_SCOPE("Val::evaluate"); - } - - ExpressionEvaluator ee; - auto evaluated_val = ee.evaluate(this); - NVF_ERROR( - evaluated_val.hasValue(), - "Detected a const value but failed to infer its value: ", - toInlineString()); - return evaluated_val; -} - -std::vector Expr::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("Expr::evaluate"); - NVF_THROW( - "`evaluate` method for expression ", - getOpString(), - " is not defined. ", - "Please override the evaluate method"); -} - -std::vector Expr::evaluate( - const ExpressionEvaluator& ee, - std::unordered_map& known_values) const { - FUSER_PERF_SCOPE("Expr::evaluate"); - std::vector expr_inputs; - expr_inputs.reserve(inputs().size()); - for (auto inp : inputs()) { - const auto& eval_i = ee.evaluate(inp, known_values); - if (!eval_i.hasValue()) { - return {std::monostate{}}; - } - expr_inputs.emplace_back(eval_i); - } - return this->evaluate(ee, expr_inputs); -} - -void Expr::addDataAttribute(PolymorphicValue attr) { - addAttribute(IrBuilder::createInContainer(container(), std::move(attr))); -} -std::vector FullOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("FullOp::evaluate"); - std::vector shape; - for (auto i : c10::irange(inputs.size() - 1)) { - shape.push_back(inputs.at(i).as()); - } - DataType dtype = getFillValue()->getDataType().value(); - const auto options = - at::TensorOptions().device(at::kCUDA).dtype(data_type_to_aten(dtype)); - using namespace PolymorphicValue_functions; - return {at::full(shape, toScalar(inputs.back()), options)}; -} - -std::vector SelectOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - 
FUSER_PERF_SCOPE("SelectOp::evaluate"); - const auto& in = inputs.at(0).as(); - int64_t dimension = dim(); - int64_t index = (int64_t)inputs.at(1); - return {in.select(dimension, index)}; -} - -std::vector IndexSelectOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("IndexSelectOp::evaluate"); - const auto& in = inputs.at(0).as(); - int64_t dimension = dim(); - const auto& indices = inputs.at(1).as().squeeze(); - return {at::index_select(in, dimension, indices)}; -} - -std::vector TorchGatherOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("TorchGatherOp::evaluate"); - const auto& input = inputs.at(0).as(); - const auto& index = inputs.at(1).as(); - auto dimension = dim(); - if (exactSizes()) { - return {at::take_along_dim(input, index, dimension)}; - } else { - return {at::gather(input, dimension, index)}; - } -} - -std::vector ScatterOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("ScatterOp::evaluate"); - const auto& input = inputs.at(0).as(); - const auto& index = inputs.at(1).as(); - const auto& src = inputs.at(2).as(); - auto dimension = dim(); - return {at::scatter(input, dimension, index, src)}; -} - -std::vector IotaOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("IotaOp::evaluate"); - const auto options = - at::TensorOptions().device(at::kCUDA).dtype(data_type_to_aten(dtype())); - int64_t length = (int64_t)inputs.at(0); - - if (isIntegralType(dtype())) { - int64_t start = (int64_t)inputs.at(1); - int64_t step = (int64_t)inputs.at(2); - int64_t end = start + step * length; - return {at::arange(start, end, step, options)}; - } else if (isFloatingPointType(dtype())) { - double start = (double)inputs.at(1); - double step = (double)inputs.at(2); - // Due to rounding error, it can be hard to guarantee the size of - // the output of arange to be 
exactly length, so we generate a - // larger tensor and truncate it to length. - double end = start + step * ((double)length + 1); - return {at::arange(start, end, step, options).narrow(0, 0, length)}; - } else { - NVF_THROW("Unsupported dtype in IotaOp evaluator: ", dtype()); - } -} - -std::vector EyeOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("EyeOp::evaluate"); - const auto options = - at::TensorOptions().device(at::kCUDA).dtype(data_type_to_aten(dtype())); - int64_t nrows = (int64_t)inputs.at(0); - if (inputs.size() > 1) { - int64_t ncols = (int64_t)inputs.at(1); - return {at::eye(nrows, ncols, options)}; - } else { - return {at::eye(nrows, options)}; - } -} - -std::vector UnaryOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("UnaryOp::evaluate"); - using namespace PolymorphicValue_functions; - - const auto& in = inputs.at(0); - if (!in.hasValue()) { - return {std::monostate{}}; - } - - switch (getUnaryOpType()) { - case UnaryOpType::Neg: - return {-in}; - case UnaryOpType::Cast: - if (in.is()) { - return {PolymorphicValue( - in.as().to(data_type_to_aten(out()->dtype())))}; - } else if (isIntegralType(*out()->getDataType())) { - return {PolymorphicValue((int64_t)in)}; - } else if (isFloatingPointType(*out()->getDataType())) { - return {PolymorphicValue((double)in)}; - } else if (out()->getDataType() == DataType::Bool) { - return {PolymorphicValue((bool)in)}; - } else if (isComplexType(*out()->getDataType())) { - return {PolymorphicValue((std::complex)in)}; - } else { - NVF_THROW("dtype not supported in evaluator: ", *out()->getDataType()); - } - case UnaryOpType::Reciprocal: - return {1.0 / in}; - break; - case UnaryOpType::Abs: - return {abs(in)}; - break; - case UnaryOpType::LogicalNot: - return {!in}; - break; - case UnaryOpType::BitwiseNot: - return {~in}; - break; - case UnaryOpType::Erf: - return {erf(in)}; - break; - case 
UnaryOpType::ToUnsignedSmemAddr: - return {(int64_t)(unsigned)in}; - break; - case UnaryOpType::AdjustPartialLdMatrixAddrInTuring8: - case UnaryOpType::AdjustPartialLdMatrixAddrInTuring16: - return {in}; - break; - case UnaryOpType::Dereference: - if (*out()->getDataType() == DataType::Float) { - return {PolymorphicValue((double)*(float*)in)}; - } else { - NVF_THROW("dtype not supported in evaluator: ", *out()->getDataType()); - } - break; - case UnaryOpType::Sigmoid: - return {in.as().sigmoid()}; - break; - case UnaryOpType::Tanh: - return {in.as().tanh()}; - break; - case UnaryOpType::Relu: - return {at::relu(in.as())}; - break; - case UnaryOpType::Gelu: - return {at::gelu(in.as())}; - break; - case UnaryOpType::Exp: - return {at::exp(in.as())}; - break; - case UnaryOpType::Sin: - return {in.as().sin()}; - break; - case UnaryOpType::Signbit: - return {signbit(in)}; - break; - case UnaryOpType::Cos: - return {in.as().cos()}; - break; - case UnaryOpType::BitCast: - NVF_CHECK( - dataTypeSize(input(0)->dtype()) == dataTypeSize(out()->dtype()), - "BitCast only works for types of the same size"); - if (isComplexType(input(0)->dtype()) && - std::holds_alternative(out()->dtype().type)) { - // view_as_real case. 
- auto vec_type = std::get(out()->dtype().type); - auto inp_scalar_type = getTypeFromComplexType(input(0)->dtype()); - NVF_CHECK( - *vec_type.type == inp_scalar_type, - "Output type must be the same as the scalar type of the complex input."); - NVF_CHECK( - vec_type.size == 2, - "Expected output to be array of size 2, found array of size ", - vec_type.size); - return {in.as()}; - } else { - return {in.as().view(data_type_to_aten(out()->dtype()))}; - } - break; - case UnaryOpType::Rsqrt: - return {in.as().rsqrt()}; - break; - case UnaryOpType::Real: - return {at::real(in.as())}; - break; - case UnaryOpType::Imag: - return {at::imag(in.as())}; - break; - case UnaryOpType::Tan: - return {in.as().tan()}; - break; - case UnaryOpType::IsFinite: - return {at::isfinite(in.as())}; - break; - default: - NVF_CHECK( - false, - "Unexpected operator type ", - getUnaryOpType(), - " in ", - toString()); - } -} - -std::vector BinaryOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("BinaryOp::evaluate"); - using namespace PolymorphicValue_functions; - const auto& lhs = inputs.at(0); - const auto& rhs = inputs.at(1); - - switch (getBinaryOpType()) { - case BinaryOpType::Add: - return {lhs + rhs}; - break; - case BinaryOpType::Sub: - return {lhs - rhs}; - break; - case BinaryOpType::Mul: - return {lhs * rhs}; - break; - case BinaryOpType::Div: - return {lhs / rhs}; - break; - case BinaryOpType::Mod: - NVF_CHECK(rhs != 0); - return {lhs % rhs}; - break; - case BinaryOpType::Fmod: - NVF_CHECK(rhs != 0); - return {fmod(lhs, rhs)}; - break; - case BinaryOpType::CeilDiv: - NVF_CHECK(rhs != 0); - return {ceildiv(lhs, rhs)}; - break; - case BinaryOpType::LogicalAnd: - return {lhs && rhs}; - break; - case BinaryOpType::LogicalOr: - return {lhs || rhs}; - break; - case BinaryOpType::BitwiseAnd: - return {lhs & rhs}; - break; - case BinaryOpType::BitwiseOr: - return {lhs | rhs}; - break; - case BinaryOpType::BitwiseXor: - return {lhs ^ rhs}; 
- break; - case BinaryOpType::Eq: - return {eq(lhs, rhs)}; - break; - case BinaryOpType::NE: - return {ne(lhs, rhs)}; - break; - case BinaryOpType::GT: - return {gt(lhs, rhs)}; - break; - case BinaryOpType::GE: - return {ge(lhs, rhs)}; - break; - case BinaryOpType::LT: - return {lt(lhs, rhs)}; - break; - case BinaryOpType::LE: - return {le(lhs, rhs)}; - break; - case BinaryOpType::Max: - return {max(lhs, rhs)}; - break; - case BinaryOpType::Min: - return {min(lhs, rhs)}; - break; - case BinaryOpType::Gcd: - return {gcd(lhs, rhs)}; - break; - case BinaryOpType::Lshift: - return {lhs << rhs}; - break; - case BinaryOpType::Rshift: - return {lhs >> rhs}; - break; - case BinaryOpType::Complex: - return {at::complex(lhs.as(), rhs.as())}; - break; - case BinaryOpType::Pow: - return {pow(lhs, rhs)}; - break; - default: - NVF_CHECK( - false, - "Unexpected operator type: ", - getBinaryOpType(), - " in ", - toString()); - } -} - -std::vector TernaryOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("TernaryOp::evaluate"); - using namespace PolymorphicValue_functions; - const auto& a = inputs.at(0); - const auto& b = inputs.at(1); - const auto& c = inputs.at(2); - switch (getTernaryOpType()) { - case TernaryOpType::Clamp: - return {std::min(std::max(a, b), c)}; - break; - case TernaryOpType::Lerp: - // This is the same lerp computed in helpers.cu - // https://math.stackexchange.com/a/1798323 - return {(c < 0.5) ? a + c * (b - a) : b - (b - a) * (1.0 - c)}; - break; - case TernaryOpType::Threshold: - return {(a <= b) ? c : a}; - break; - case TernaryOpType::Where: - return {a.as() ? 
b : c}; - break; - default: - NVF_CHECK( - false, - "Unexpected operator type: ", - getTernaryOpType(), - " in ", - toString()); - } -} - -std::vector ArrayConstruct::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("ArrayConstruct::evaluate"); - return {PolymorphicValue(inputs)}; -} - -std::vector ReverseArray::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("ReverseArray::evaluate"); - NVF_ERROR(inputs.size() == 1, "ReverseArray expects 1 input"); - PolymorphicValue array = inputs.at(0); - auto& vec = array.as(); - std::reverse(vec.begin(), vec.end()); - return {std::move(array)}; -} - -std::vector GetItem::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("GetItem::evaluate"); - NVF_ERROR(inputs.size() == 2, "GetItem expects 2 inputs"); - return {PolymorphicValue(inputs.at(0)[inputs.at(1)])}; -} - -std::vector StructConstruct::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("StructConstruct::evaluate"); - NVF_ERROR( - this->inputs().size() == inputs.size(), - "StructConstruct expects ", - this->inputs().size(), - " inputs"); - PolymorphicValue struct_ = - std::get(output(0)->dtype().type).create(); - for (int64_t i : c10::irange((int64_t)inputs.size())) { - struct_->*attribute(i) = inputs.at(i); - } - return {std::move(struct_)}; -} - -std::vector GetAttr::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("GetAttr::evaluate"); - NVF_ERROR(inputs.size() == 1, "GetAttr expects 1 input"); - return {inputs.at(0)->*attr()}; -} - -std::vector TensorConstruct::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("TensorConstruct::evaluate"); - NVF_ERROR(inputs.size() == 1, "TensorConstruct expects 1 input"); - using namespace PolymorphicValue_functions; - return 
{toTensor(inputs.at(0))}; -} - -std::vector BroadcastOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("BroadcastOp::evaluate"); - NVF_ERROR( - inputs.size() == 1, - "BroadcastOp expects exactly 1 input, but received ", - inputs.size()); - std::vector out_shape; - const auto& in = inputs.at(0).as(); - int64_t idx = 0; - for (bool b : getBroadcastDimFlags()) { - if (b) { - out_shape.push_back(1); - } else { - out_shape.push_back(in.sizes()[idx++]); - } - } - return {in.view(out_shape)}; -} - -std::vector SqueezeOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("SqueezeOp::evaluate"); - NVF_ERROR( - inputs.size() == 1, - "SqueezeOp expects exactly 1 input, but received ", - inputs.size()); - std::vector out_shape; - const auto& in = inputs.at(0).as(); - const auto& is_squeeze_dims = getSqueezeDimFlags(); - NVF_ERROR( - (int64_t)is_squeeze_dims.size() == in.dim(), - "The dimensions of input tensor and does not match with is_squeeze_dims"); - at::Tensor out = in; - for (int64_t i : c10::irange((int64_t)is_squeeze_dims.size())) { - if (is_squeeze_dims[i]) { - if (in.stride(i) == 0) { - // If the input dimension is expanded in this dimension, undo the expand - // by slicing. 
This ensures that any broadcast dimensions will be - // unexpanded when we do the final call to view() - out = out.slice(i, 0, 1); - } - } else { - out_shape.push_back(in.sizes()[i]); - } - } - return {out.view(out_shape)}; -} - -std::vector ReductionOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("ReductionOp::evaluate"); - const auto& input = inputs.at(0).as(); - const auto output = out()->as(); - - NVF_ERROR( - !output->hasRoot(), - "Evaluation for rFactored reductions is not supported."); - - std::vector reduction_axes; - for (const auto i : c10::irange(int64_t(output->getLogicalDomain().size()))) { - auto ax = output->getLogicalDomain().at(i); - if (ax->isReduction()) { - reduction_axes.push_back(i); - } - } - switch (getReductionOpType()) { - case BinaryOpType::Add: - return {at::sum(input, reduction_axes)}; - break; - case BinaryOpType::Max: - return {at::amax(input, reduction_axes)}; - break; - case BinaryOpType::Min: - return {at::amin(input, reduction_axes)}; - break; - default: - NVF_CHECK( - false, - "Unexpected operator type: ", - getReductionOpType(), - " in ", - toString()); - } -} - -std::vector GroupedReductionOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("GroupedReductionOp::evaluate"); - const auto num_reductions = numHorizontallyGroupedExprs(); - std::vector grouped_reduction_out; - grouped_reduction_out.reserve(num_reductions); - for (const auto i : c10::irange(num_reductions)) { - const auto& in_tensor = inputs.at(i).as(); - const auto out_tv = output(i)->as(); - NVF_ERROR( - !out_tv->hasRoot(), - "Evaluation for rFactored reductions is not supported."); - - std::vector reduction_axes; - for (const auto id : - c10::irange(int64_t(out_tv->getLogicalDomain().size()))) { - auto ax = out_tv->getLogicalDomain().at(id); - if (ax->isReduction()) { - reduction_axes.push_back(id); - } - } - switch (getReductionOpType(i)) { - case 
BinaryOpType::Add: - grouped_reduction_out.emplace_back(at::sum(in_tensor, reduction_axes)); - break; - case BinaryOpType::Max: - grouped_reduction_out.emplace_back(at::amax(in_tensor, reduction_axes)); - break; - default: - NVF_CHECK( - false, - "Unexpected operator type: ", - getReductionOpType(i), - " in ", - toString()); - } - } - return grouped_reduction_out; -} - -std::vector WelfordOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("WelfordOp::evaluate"); - NVF_ERROR( - !hasInit(), - "Evaluation for WelfordOp is not implemented for non-empty initial values."); - const auto& in_tensor = inputs.at(0).as(); - const auto out_tv = out()->as(); - NVF_ERROR( - !out_tv->hasRoot(), - "Evaluation for WelfordOp is not supported when output is rFactored."); - - int64_t N = 1; - std::vector reduction_axes; - for (const auto i : c10::irange(int64_t(out_tv->getLogicalDomain().size()))) { - auto ax = out_tv->getLogicalDomain().at(i); - if (ax->isReduction()) { - reduction_axes.push_back(i); - N *= in_tensor.size(i); - } - } - const auto [in_var, in_avg] = - at::var_mean(in_tensor, reduction_axes, false, false); - return {in_avg, in_var * N, N}; -} - -std::vector ExpandOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("ExpandOp::evaluate"); - const auto& in = inputs.at(0).as(); - std::vector expanded_size; - for (auto i : c10::irange(1, inputs.size())) { - expanded_size.push_back((int64_t)inputs.at(i)); - } - return {in.expand(expanded_size)}; -} - -std::vector RepeatOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("RepeatOp::evaluate"); - NVF_ERROR( - inputs.size() == 1, - "RepeatOp expects exactly 1 input, but received ", - inputs.size()); - auto tensor = inputs.at(0).as(); - std::vector multipliers; - multipliers.reserve(out()->getLogicalDomain().size()); - const auto c2p = - PairwiseLogicalDomainMap(in(), 
out()).mapConsumerToProducer(); - for (const auto i : c10::irange(out()->getLogicalDomain().size())) { - auto out_id = out()->getLogicalDomain().at(i); - auto inp_id = c2p.at(out_id); - auto out_extent = ee.evaluate(out_id->extent()).as(); - auto inp_extent = ee.evaluate(inp_id->extent()).as(); - NVF_ERROR( - out_extent % inp_extent == 0, - "For dimension ", - i, - ", the output extent (", - out_extent, - " should be a multiple of the input extent (", - inp_extent, - ")."); - multipliers.push_back(out_extent / inp_extent); - } - return {tensor.repeat(multipliers)}; -} - -std::vector ViewAsScalar::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("ViewAsScalar::evaluate"); - const at::Tensor& in = inputs.at(0).as(); - return {at::view_as_real(in)}; -} - -std::vector ViewOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("ViewOp::evaluate"); - - NVF_ERROR(inputs.size() == 1); - const at::Tensor& in_tensor = inputs[0].as(); - - const std::vector& out_logical = out()->getLogicalDomain(); - std::vector out_shape; - out_shape.reserve(out_logical.size()); - - int missing_vals = - std::count_if(out_logical.begin(), out_logical.end(), [](IterDomain* id) { - return !id->isDeviceDim() && - !id->getMaybeExpandedExtent()->isConstScalar(); - }); - - for (IterDomain* id : out_logical) { - if (id->isDeviceDim()) { - out_shape.push_back(1); - } else if (id->getMaybeExpandedExtent()->isConstScalar()) { - out_shape.push_back( - id->getMaybeExpandedExtent()->evaluate().as()); - } else { - if (missing_vals == 1) { - out_shape.push_back(-1); - } else { - out_shape.push_back( - ee.evaluate(id->getMaybeExpandedExtent()).as()); - } - } - } - - // TODO: check allocation domain and contiguity. - - // Use `at::Tensor::reshape` instead of `at::Tensor::view` because `ViewOp` - // doesn't always produce an alias. 
For example, when merging an expanded - // `IterType::Broadcast` and an `IterType::Iteration`, `ViewOp` has to realize - // the expand. - if (in_tensor.is_contiguous()) { - return {in_tensor.view(out_shape)}; - } - return {in_tensor.reshape(out_shape)}; -} - -std::vector LoadStoreOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("LoadStoreOp::evaluate"); - - if (TensorView* out_tv = dynamic_cast(out())) { - if (out_tv->hasRoot()) { - std::optional> permutation = - ir_utils::computePermutation( - out_tv->getRootDomain(), out_tv->getLogicalDomain()); - NVF_ERROR( - permutation.has_value(), - "The logical domain of a Set.Permute is supposed to be a permutation of the root domain: ", - out_tv->toString()); - NVF_ERROR(inputs.size() == 1); - at::Tensor in_tensor = inputs[0].as(); - at::Tensor out_tensor = in_tensor.permute(*permutation); - return {out_tensor}; - } - } - return inputs; -} - -std::vector PadOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("PadOp::evaluate"); - const auto& in = inputs.at(0).as(); - - std::vector pad_widths; - auto pad_width_offset = getPadWidthInputOffset(); - auto num_dims = in.dim(); - - for (auto i = num_dims - 1; i > -1; i--) { - auto left_pad = (int64_t)inputs.at(pad_width_offset + 2 * i); - auto right_pad = (int64_t)inputs.at(pad_width_offset + 2 * i + 1); - pad_widths.push_back(left_pad); - pad_widths.push_back(right_pad); - } - - if (isComplexType(*out()->getDataType())) { - std::complex value = - static_cast>(inputs.at(1)); - auto real = at::real(in); - auto imag = at::imag(in); - auto padded_real = at::pad(real, pad_widths, "constant", value.real()); - auto padded_imag = at::pad(imag, pad_widths, "constant", value.imag()); - return {at::complex(padded_real, padded_imag)}; - } else { - double value = static_cast(inputs.at(1)); - return {at::pad(in, pad_widths, "constant", value)}; - } -} - -std::vector SliceOp::evaluate( - 
const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("SliceOp::evaluate"); - const auto& in = inputs.at(0).as(); - std::vector ranges; - auto ranges_offset = getRangeInputOffset(); - auto num_dims = in.dim(); - for (const auto i : c10::irange(num_dims)) { - auto start = (int64_t)inputs.at(ranges_offset + 3 * i); - auto stop = (int64_t)inputs.at(ranges_offset + 3 * i + 1); - auto step = (int64_t)inputs.at(ranges_offset + 3 * i + 2); - ranges.emplace_back(at::indexing::Slice(start, stop, step)); - } - return {in.index(ranges)}; -} - -std::vector CatOp::evaluate( - const ExpressionEvaluator& ee, - std::unordered_map& known_values) const { - FUSER_PERF_SCOPE("CatOp::evaluate"); - // CatOp is preceded by a PadOp internally. - // For ATen evaluation, directly compute the unpadded inputs. - std::vector unpadded_inputs; - unpadded_inputs.reserve(inputs().size()); - int64_t concat_dim = concatenatedDim(); - for (Val* inp : inputs()) { - NVF_CHECK( - inp->definition() != nullptr && inp->definition()->isA(), - "Expected CatOp to be preceded by a PadOp."); - auto eval_i = ee.evaluate(inp->definition()->input(0), known_values); - unpadded_inputs.push_back(eval_i.as()); - } - return {at::cat(unpadded_inputs, concat_dim)}; -} - -std::vector MatmulOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("MatmulOp::evaluate"); - const auto a = inputs.at(0).as(); - const auto b = inputs.at(1).as(); - - auto matmul_out = at::matmul(a, b); - - // When the contracting dimension is sharded, each device has a partial - // matmul output and is followed by an allreduce. For loop split, this is - // represented as an rfactored reduction. The local matmul logical domain - // after the rfactor is: i{DIDx}, i{M}, i{N}, r{K//d}. Unsqueeze the - // rfactored DID axis to correctly bind with the logical domain. 
See - // tests/python/test_multidevice.py/test_matmul_allreduce_loop_split - auto out_logical = TensorDomain::noReductions(out()->getLogicalDomain()); - int64_t rfactor_did_idx = -1; - for (auto idx : c10::irange(static_cast(out_logical.size()))) { - if (!out_logical.at(idx)->isRFactorProduct() || - !out_logical.at(idx)->isDeviceDim()) { - continue; - } - if (rfactor_did_idx != -1) { - NVF_THROW( - "Expected only 1 rfactored DID iterdomain, found at least 2 in ", - out_logical); - } - rfactor_did_idx = idx; - } - - if (rfactor_did_idx != -1) { - matmul_out = matmul_out.unsqueeze(rfactor_did_idx); - } - - const auto& [sizes, strides] = inferShapeOfOutput(out(), ee); - auto meta_out = at::detail::empty_strided_meta(sizes, strides, a.dtype()); - - if (meta_out.is_contiguous()) { - return {matmul_out}; - } - - auto strided_matmul_out = at::empty_strided(sizes, strides, a.options()); - strided_matmul_out = strided_matmul_out.copy_(matmul_out); - return {strided_matmul_out}; -} - -std::vector LinearOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("LinearOp::evaluate"); - const auto in = inputs.at(0).as(); - auto weight = inputs.at(1).as(); - - auto squeeze_device_dims = [](at::Tensor& t, - int64_t num_device_dims) -> void { - // Record the initial shape for the error message. - std::vector shape = t.sizes().vec(); - for ([[maybe_unused]] auto _ : c10::irange(num_device_dims)) { - NVF_CHECK( - t.size(0) == 1, - "When the weight is >2D, expect its preceding dimensions and " - "the bias's preceding dimensions to " - "be DID-parallel and therefore size-1: ", - shape); - t = t.squeeze(0); - } - }; - - // The squeezes and unsqueezes are currently required to support a sharded - // linear layer. Remove them after #2563. 
- auto num_device_dims = weight.dim() - 2; - squeeze_device_dims(weight, num_device_dims); - - at::Tensor out; - if (has_bias()) { - auto bias = inputs.at(2).as(); - squeeze_device_dims(bias, num_device_dims); - out = at::linear(in, weight, bias); - } else { - out = at::linear(in, weight); - } - - for ([[maybe_unused]] auto _ : c10::irange(num_device_dims)) { - out = out.unsqueeze(0); - } - return {out}; -} - -std::vector SdpaFwdOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("SdpaFwdOp::evaluate"); - auto query = inputs.at(0).as(); - auto key = inputs.at(1).as(); - auto value = inputs.at(2).as(); - - const auto dropout_p = inputs.at(3).as(); - const auto is_causal = inputs.at(4).as(); - - // Temporary handling of DID parallelization see - // https://github.com/NVIDIA/Fuser/issues/2563 - bool handle_device_dim = false; - if (query.dim() == 5) { - handle_device_dim = true; - - NVF_CHECK(key.dim() == 5 && value.dim() == 5); - - auto query_domain = - TensorDomain::noReductions(this->query()->getLogicalDomain()); - auto key_domain = - TensorDomain::noReductions(this->key()->getLogicalDomain()); - auto value_domain = - TensorDomain::noReductions(this->value()->getLogicalDomain()); - NVF_CHECK( - query_domain.front()->isDeviceDim(), - "Only support DID parallelization on outermost axis"); - NVF_CHECK( - key_domain.front()->isDeviceDim(), - "Only support DID parallelization on outermost axis"); - NVF_CHECK( - value_domain.front()->isDeviceDim(), - "Only support DID parallelization on outermost axis"); - - query = query.squeeze(0); - key = key.squeeze(0); - value = value.squeeze(0); - } - - // Flash attention requires the last dimension to be padded to 8. 
- // https://github.com/pytorch/pytorch/blob/c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad/aten/src/ATen/native/transformers/attention.cpp#L675-L677 - const auto last_dim_size = query.size(-1); - auto pad_last_dim = [last_dim_size]( - at::Tensor inp, int alignment_size) -> at::Tensor { - if (last_dim_size % alignment_size == 0) { - return inp; - } - auto pad_count = alignment_size - (last_dim_size % alignment_size); - auto padded_inp = at::pad(inp, {0, pad_count}); - return padded_inp; - }; - - query = pad_last_dim(query, 8); - key = pad_last_dim(key, 8); - value = pad_last_dim(value, 8); - - // Conmpute scale using original size of last dimension - double scale = inputs.size() > 5 ? inputs.back().as() - : 1.0 / std::sqrt(last_dim_size); - - // ATen reference: - // https://github.com/pytorch/pytorch/blob/c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad/aten/src/ATen/native/transformers/attention.cpp#L680-L681 - auto - [output, - log_sumexp, - cum_seq_q, - cum_seq_k, - query_seq_len, - key_seq_len, - philox_seed, - philox_offset, - debug_attn_mask] = - at::_scaled_dot_product_flash_attention( - query, - key, - value, - dropout_p, - is_causal, - /*return_debug_mask=*/false, - scale); - - // If the inputs were padded, slice the output to restore the original - // size - if (output.size(-1) != last_dim_size) { - output = output.slice(-1, 0, last_dim_size); - } - - // Add back the device dim axis for output. - if (handle_device_dim) { - output = output.unsqueeze(0); - log_sumexp = log_sumexp.unsqueeze(0); - } - - // We ignore cum_seq_q/k outputs since they are undefined tensors for - // non-nested tensors. We do not store query/key_seq_len since they can be - // computed in non-nested tensor directly. debug_attn_mask is ignored - // since `return_debug_mask=false`. 
- return {output, log_sumexp, philox_seed, philox_offset}; -} - -std::vector SdpaBwdOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("SdpaBwdOp::evaluate"); - // Backward tensor inputs: grad_input, query, key, value, output, - // logsumexp, max_q/k Temporary handling of DID parallelization. See - // https://github.com/NVIDIA/Fuser/issues/2563 - bool first_dim_is_did = this->key()->as()->axis(0)->isDeviceDim(); - auto out_grad = inputs[0].as(); - if (first_dim_is_did) { - NVF_CHECK(out_grad.dim() == 5, "Expected 5D but found ", out_grad.sizes()); - } else { - NVF_CHECK(out_grad.dim() == 4, "Expected 4D but found ", out_grad.sizes()); - } - - std::vector bwd_inputs; - for (auto idx : c10::irange(6)) { - auto in_tensor = inputs.at(idx).as(); - // Removing the size 1 from sharded axis from tensors. - if (first_dim_is_did) { - in_tensor = in_tensor.squeeze(0); - } - bwd_inputs.push_back(in_tensor); - } - const auto dropout_p = inputs.at(6).as(); - const auto is_causal = inputs.at(7).as(); - const auto philox_seed = inputs.at(8).as(); - const auto philox_offset = inputs.at(9).as(); - - // Flash attention requires the last dimension to be padded to 8. - // https://github.com/pytorch/pytorch/blob/c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad/aten/src/ATen/native/transformers/attention.cpp#L675-L677 - const auto last_dim_size = bwd_inputs[0].size(-1); - auto pad_last_dim = [last_dim_size]( - at::Tensor inp, int alignment_size) -> at::Tensor { - if (last_dim_size % alignment_size == 0) { - return inp; - } - auto pad_count = alignment_size - (last_dim_size % alignment_size); - auto padded_inp = at::pad(inp, {0, pad_count}); - return padded_inp; - }; - - // Conmpute scale using original size of last dimension - double scale = inputs.size() > 10 ? 
inputs.back().as() - : 1.0 / std::sqrt(last_dim_size); - - // ATen reference: - // https://github.com/pytorch/pytorch/blob/c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad/aten/src/ATen/native/transformers/attention.cpp#L680-L681 - // cum_seq_q/k are undefined tensors for non-nested input tensors. - auto [grad_query, grad_key, grad_value] = - at::_scaled_dot_product_flash_attention_backward( - /*grad_output=*/pad_last_dim(bwd_inputs[0], 8), - /*query=*/pad_last_dim(bwd_inputs[1], 8), - /*key=*/pad_last_dim(bwd_inputs[2], 8), - /*value=*/pad_last_dim(bwd_inputs[3], 8), - /*output=*/pad_last_dim(bwd_inputs[4], 8), - /*logsumexp=*/bwd_inputs[5], - /*cum_seq_q=*/at::Tensor(), - /*cum_seq_k=*/at::Tensor(), - // Note: ATen implementation expects max_q/max_k as scalars. - /*max_q=*/bwd_inputs[1].size(2), - /*max_k=*/bwd_inputs[2].size(2), - /*dropout_p=*/dropout_p, - /*is_causal=*/is_causal, - /*philox_seed=*/philox_seed, - /*philox_offset=*/philox_offset, - /*scale=*/scale); - - // If the inputs were padded, slice the gradsto restore the original size - auto slice_last_dim = [last_dim_size](at::Tensor output) -> at::Tensor { - if (output.size(-1) != last_dim_size) { - return output; - } - return output.slice(-1, 0, last_dim_size); - }; - - // Add device dimension back to outputs. 
- if (first_dim_is_did) { - grad_query = grad_query.unsqueeze(0); - grad_key = grad_key.unsqueeze(0); - grad_value = grad_value.unsqueeze(0); - } - - return { - slice_last_dim(grad_query), - slice_last_dim(grad_key), - slice_last_dim(grad_value)}; -} - -std::vector EmbeddingFwdOp::evaluate( - const ExpressionEvaluator& ee, - const std::vector& inputs) const { - FUSER_PERF_SCOPE("EmbeddingFwdOp::evaluate"); - auto input = inputs.at(0).as(); - auto weight = inputs.at(1).as(); - auto norm_type = inputs.at(2).as(); - auto scale_grad_by_freq = inputs.at(3).as(); - auto sparse = inputs.at(4).as(); - std::optional padding_idx = std::nullopt; - if (has_padding_idx()) { - padding_idx = inputs.at(5).as(); - } - std::optional max_norm = std::nullopt; - if (has_max_norm()) { - auto idx = 5 + has_padding_idx(); - max_norm = inputs.at(idx).as(); - } - - namespace F = torch::nn::functional; - return {F::embedding( - input, - weight, - F::EmbeddingFuncOptions() - .padding_idx(padding_idx) - .max_norm(max_norm) - .norm_type(norm_type) - .scale_grad_by_freq(scale_grad_by_freq) - .sparse(sparse))}; -} - -} // namespace nvfuser diff --git a/csrc/runtime/expr_eval_exec.cpp b/csrc/runtime/expr_eval_exec.cpp deleted file mode 100644 index 6c484d753c2..00000000000 --- a/csrc/runtime/expr_eval_exec.cpp +++ /dev/null @@ -1,353 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. 
- * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on - -#include - -#include -#include -#include -#include -#include - -#include - -namespace nvfuser { - -bool ExprEvalExecutor::supported(Fusion* fusion) { - FUSER_PERF_SCOPE("ExprEvalExecutor::supported"); - return std::all_of( - fusion->outputs().begin(), fusion->outputs().end(), [&fusion](Val* out) { - return fusion->getOutputAlias(out).type == AllocationType::Evaluate; - }); -} - -void ExprEvalExecutor::findAndBindInputTVExtentsFrom( - VectorOfUniqueEntries vals) { - for (auto val : vals) { - if (val->isFusionInput()) { - // Could be an input scalar value that will be bound when we bind inputs. - vals.erase(val); - continue; - } - if (val->isConstInt()) { - // Const scalars don't need to be bound - vals.erase(val); - continue; - } - auto tv_info_it = extent_to_tv_info.find(val); - if (tv_info_it != extent_to_tv_info.end()) { - // val is a TV logical ID, use that - tv_sizes_to_bind.pushBack(tv_info_it->second); - vals.erase(val); - } - } - - auto deps = DependencyCheck::getAllValsBetween( - all_potential_input_scalars.set(), vals.vector()); - VectorOfUniqueEntries unique_deps(deps); - auto inputs = all_potential_input_scalars.computeIntersect(unique_deps); - - for (auto inp : inputs) { - if (inp->isConstInt()) { - // Const scalars don't need to be bound - continue; - } - if (inp->isFusionInput()) { - // Could be an input scalar value that will be bound when we bind inputs. 
- continue; - } - tv_sizes_to_bind.pushBack(extent_to_tv_info[inp]); - } -} - -void ExprEvalExecutor::compile(Fusion* fusion) { - FUSER_PERF_SCOPE("ExprEvalExecutor::compile"); - if (isProfilerEnabled()) { - FusionProfiler::segment(group_id_).startCompile(); - } - NVF_ERROR( - supported(fusion), - "ExprEvalExecutor does not support the Fusion provided."); - fusion_ = std::make_unique(*fusion); - - auto extent_simplification_map = getSimplificationMap(fusion_.get()); - auto mutation_map = - ir_utils::replaceValue(fusion_.get(), extent_simplification_map); - - // Build extent to input tv info map - for (auto inp_id : c10::irange(fusion_->inputs().size())) { - if (TensorView* tv = dynamic_cast(fusion_->inputs()[inp_id])) { - auto domain = TensorDomain::noReductions(tv->getLogicalDomain()); - for (auto id_i : c10::irange(domain.size())) { - auto extent = domain[id_i]->getMaybeExpandedExtent(); - extent_to_tv_info[extent] = {tv, inp_id, id_i}; - all_potential_input_scalars.pushBack( - domain[id_i]->getMaybeExpandedExtent()); - } - } - if (fusion_->inputs()[inp_id]->isIntegralScalar()) { - all_potential_input_scalars.pushBack(fusion_->inputs()[inp_id]); - } - } - - exprs_ = fusion_->exprs(); - for (auto expr : exprs_) { - if (expr->isA()) { - compile(expr->as()); - } else if (expr->isA()) { - compile(expr->as()); - } - // TODO: support RepeatOp and GetMetaData - NVF_ERROR( - !expr->isA() && !expr->isA(), - "Repeat op and MetaDataOp not implemented yet, found: ", - expr->toString()); - for (auto expr_inp : expr->inputs()) { - if (expr_inp->isIntegralScalar()) { - needed_integer_scalars.pushBack(expr_inp); - } - } - } - - findAndBindInputTVExtentsFrom(needed_integer_scalars); - - if (isProfilerEnabled()) { - FusionProfiler::segment(group_id_).stopCompile(); - } -} - -bool ExprEvalExecutor::isCompiled() const { - return fusion_ != nullptr; -} - -std::vector ExprEvalExecutor::run( - KernelArgumentHolder& args, - std::vector outputs) { - 
FUSER_PERF_SCOPE("ExprEvalExecutor::run"); - - NVF_ERROR( - outputs.empty(), - "Fusion executor is using expression evaluator,", - " and expects that the outputs are not populated, which they were."); - - if (isProfilerEnabled()) { - NVF_CHECK( - group_id_ >= 0, - "An invalid segment id is passed to FusionProfiler!:", - group_id_); - SegmentProfiler& sprof = FusionProfiler::segment(group_id_); - sprof.inputBytesAccessed(computeBytes(args)); - sprof.scheduler(toString(SchedulerType::ExprEval)); - sprof.startKernel(); - } - - NVF_ERROR(fusion_, "Need to compile before you can run."); - // Bind fusion inputs - ExpressionEvaluator expr_eval; - { - FUSER_PERF_SCOPE("ExprEvalExecutor::bindInputs"); - // expr_eval = executor_utils::bindInputs(args, fusion_.get()); - NVF_ERROR( - fusion_->inputs().size() <= args.size(), - "KernelArgumentHolder contains less argument than fusion's input."); - for (auto inp_i : c10::irange(fusion_->inputs().size())) { - expr_eval.unsafeBind(fusion_->inputs()[inp_i], *args[inp_i]); - } - - for (auto tv_info : tv_sizes_to_bind) { - NVF_ERROR( - tv_info.fusion_input_pos < fusion_->inputs().size(), - "Error processing tv_info, asked for fusion input ", - tv_info.fusion_input_pos, - " but fusion only has ", - fusion_->inputs().size(), - " inputs"); - - Val* fusion_input = fusion_->inputs()[tv_info.fusion_input_pos]; - - NVF_ERROR( - fusion_input->isA(), - "Expected provided input to be a tensor view but found ", - fusion_input->toString()); - - auto tv = fusion_input->as(); - - NVF_ERROR( - tv == tv_info.tv, - "Expected fusion input[", - tv_info.fusion_input_pos, - "] to be ", - tv_info.tv->toString(), - " but found ", - tv->toString()); - - auto logical_domain = TensorDomain::noReductions(tv->getLogicalDomain()); - - NVF_ERROR( - tv_info.logical_dim_pos < logical_domain.size(), - "Expected tensor view, ", - tv->toString(), - ", to have a logical domain of size at least ", - tv_info.logical_dim_pos, - " but only found ", - logical_domain.size(), 
- " dimensions."); - expr_eval.unsafeBind( - logical_domain[tv_info.logical_dim_pos]->getMaybeExpandedExtent(), - (*args[tv_info.fusion_input_pos]) - .as() - .sizes()[tv_info.logical_dim_pos]); - } - } - { - FUSER_PERF_SCOPE("ExprEvalExecutor::Eval"); - for (auto expr : exprs_) { - if (ViewOp* view = dynamic_cast(expr)) { - auto output_tensor = run(view, expr_eval); - expr_eval.unsafeBind(view->out(), output_tensor); - continue; - } else if (LoadStoreOp* ld_st_op = dynamic_cast(expr)) { - auto output_tensor = - run(ld_st_op, expr_eval.evaluate(ld_st_op->in()).as()); - expr_eval.unsafeBind(ld_st_op->out(), output_tensor); - continue; - } - auto infer_val = expr_eval.evaluate(expr->outputs()[0]); - } - - for (const auto& out_val : fusion_->outputs()) { - auto out_tensor = expr_eval.evaluate(out_val).as(); - outputs.emplace_back(out_tensor); - } - } - if (isProfilerEnabled()) { - FusionProfiler::segment(group_id_).stopKernel(); - FusionProfiler::segment(group_id_).setDevice(args.getDeviceIndex()); - } - return outputs; -} - -namespace { -bool isContiguous(TensorView* tv) { - auto logical = TensorDomain::noReductions(tv->getLogicalDomain()); - auto alloc = TensorDomain::noReductions(tv->getMaybeAllocationDomain()); - if (logical.size() != alloc.size()) { - return false; - } - for (int64_t id_i : c10::irange(logical.size())) { - if (logical[id_i]->isBroadcast() && alloc[id_i]->isBroadcast()) { - if (logical[id_i]->hasExpandedExtent()) { - return false; - } - continue; - } - if (logical[id_i] != alloc[id_i]) { - return false; - } - if (!tv->getContiguity()[id_i]) { - return false; - } - } - return true; -} -} // namespace - -void ExprEvalExecutor::compile(ViewOp* view_op) { - FUSER_PERF_SCOPE("ExprEvalExecutor::compile(ViewOp* view_op"); - std::vector sizes; - - for (auto id : view_op->out()->getLogicalDomain()) { - // Ignore sharded dimensions - if (id->isDeviceDim()) { - sizes.push_back(FusionGuard::getCurFusion()->oneVal()); - continue; - } - - // Constant reshape 
specified dimensions - auto id_size = id->getMaybeExpandedExtent(); - if (id_size->isConstInt() && id_size->definition() != nullptr) { - sizes.push_back( - IrBuilder::create(id_size->evaluate().as())); - continue; - } - - sizes.push_back(id_size); - } - - int missing_vals = std::count_if(sizes.begin(), sizes.end(), [](Val* size) { - return !size->isConstScalar(); - }); - - // Record which vals need to be inferred and what input bindings we need to - // infer them. - if (missing_vals > 1) { - for (auto size : sizes) { - needed_integer_scalars.pushBack(size); - } - } - - ViewInfo view_info = {sizes, missing_vals <= 1, isContiguous(view_op->in())}; - - view_infos[view_op] = view_info; -} - -at::Tensor ExprEvalExecutor::run( - ViewOp* view_op, - ExpressionEvaluator& expr_eval) { - FUSER_PERF_SCOPE("ExprEvalExecutor::run(ViewOp* view_op"); - auto view_info_it = view_infos.find(view_op); - NVF_ERROR( - view_info_it != view_infos.end(), - "Error running ViewOp, it wasn't compiled."); - ViewInfo& view_info = view_info_it->second; - - std::vector sizes; - for (auto size : view_info.output_view_sizes) { - if (size->isConstInt()) { - sizes.push_back(size->value().as()); - } else if (view_info.use_neg_1) { - sizes.push_back(-1); - } else { - expr_eval.evaluate(size).as(); - } - } - - auto input = expr_eval.evaluate(view_op->in()).as(); - - if (view_info.use_at_view) { - return input.view(sizes); - } - return input.reshape(sizes); -} - -void ExprEvalExecutor::compile(LoadStoreOp* ld_st_op) { - FUSER_PERF_SCOPE("ExprEvalExecutor::compile(LoadStoreOp* ld_st_op"); - if (TensorView* out_tv = dynamic_cast(ld_st_op->out())) { - if (out_tv->hasRoot()) { - std::optional> permutation = - ir_utils::computePermutation( - out_tv->getRootDomain(), out_tv->getLogicalDomain()); - NVF_ERROR( - permutation.has_value(), - "The logical domain of a Set.Permute is supposed to be a permutation of the root domain: ", - out_tv->toString()); - permutation_orders[ld_st_op] = *permutation; - } - } -} - 
-at::Tensor ExprEvalExecutor::run(LoadStoreOp* ld_st_op, at::Tensor input) { - FUSER_PERF_SCOPE("ExprEvalExecutor::run(LoadStoreOp* ld_st_op"); - auto permute_it = permutation_orders.find(ld_st_op); - if (permute_it == permutation_orders.end()) { - return input; - } - return input.permute(permute_it->second); -} - -} // namespace nvfuser diff --git a/csrc/runtime/expr_eval_exec.h b/csrc/runtime/expr_eval_exec.h deleted file mode 100644 index c93f6d61511..00000000000 --- a/csrc/runtime/expr_eval_exec.h +++ /dev/null @@ -1,119 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#pragma once - -#include -#include -#include - -namespace nvfuser { - -class ExprEvalExecutor : public ExecutorAbstract { - public: - ExprEvalExecutor( - int64_t fusion_id = 0, - int64_t concrete_id = 0, - int64_t runtime_id = 0, - int64_t group_id = 0) - : ExecutorAbstract(fusion_id, concrete_id, runtime_id, group_id) {} - - // Returns true if all fusion outputs are expression evaluated. - static bool supported(Fusion* fusion); - - void compile(Fusion* fusion); - - bool isCompiled() const override; - - NVF_API std::vector run( - KernelArgumentHolder& args, - std::vector outputs = {}); - - const std::unique_ptr& fusion() { - return fusion_; - } - - private: - std::unique_ptr fusion_; - - // Expressions to evaluate - std::vector exprs_; - - struct ViewInfo { - // Sizes of the output of view ops, only one value can be unknown at it gets - // processed in aten as a -1 size, every other dim is a constant positive - // integer value. - std::vector output_view_sizes; - // PyTorch's API defines all output shapes as a constant known size except - // upto 1 which can be easily inferred based on the input numel and the rest - // of the ouput sizes. 
nvFuser can have dynamic reshape operations where the - // output sizes are inferred through split and merge operations on IDs. If - // use_neg_1 is true then all values except up to one are constant values. - bool use_neg_1 = false; - // at::view can be used on contiguous tensors and is faster than - // at::reshape. Since we know at compile time if the tensor is contiguous - // then we can route evaluation to view. - bool use_at_view = false; - }; - - std::unordered_map view_infos; - - // Permute map, stores permutation axes if a LoadStoreOp requires them. - std::unordered_map> permutation_orders; - - struct TVInfo { - TensorView* tv; - uint64_t fusion_input_pos; - uint64_t logical_dim_pos; - - bool operator==(const TVInfo& other) const { - return tv == other.tv && fusion_input_pos == other.fusion_input_pos && - logical_dim_pos == other.logical_dim_pos; - } - }; - - // For use with VectorOfUniqueEntries - struct TVInfoHash { - std::size_t operator()(const TVInfo& info) const { - std::size_t hash = 0; - hash ^= std::hash()(info.tv); - hash ^= std::hash()(info.fusion_input_pos); - hash ^= std::hash()(info.logical_dim_pos) << 8; - return hash; - } - }; - - // Expr eval exec only shallowly binds inputs. This means all sizes of each - // tensor are not bound. During compilation information about which size - // information needs to be pulled and bound are tracked. References entries in - // extent_to_tv_info map. - VectorOfUniqueEntries tv_sizes_to_bind; - std::unordered_map extent_to_tv_info; - - // Since input tensor views could be from an intermediate segmentation their - // logical domains could be a function of iter domains of a previous fusions. - // This means an input tensor could have an iter domain for example: - // iS24{( ceilDiv(( i0 * i2 ), 3) )} where i0 and i2 are not "inputs" to - // the fusion. This means we want to bind a the size of the input tensor to - // the entire scalar, not to i0 and i2. 
This unordered set will contain all - // input scalars and all logical domain scalars of input tensors, to resolve - // how to infer all necessary scalars for the fusion. - VectorOfUniqueEntries all_potential_input_scalars; - - // The scalars that need to be infered during execution. - VectorOfUniqueEntries needed_integer_scalars; - // Goes to val's inputs and check if it's from a TensorView, if so it fills - // tv_sizes_to_bind for those inputs. - void findAndBindInputTVExtentsFrom(VectorOfUniqueEntries vals); - - void compile(ViewOp* view_op); - at::Tensor run(ViewOp* view_op, ExpressionEvaluator& expr_eval); - - void compile(LoadStoreOp* ld_st_op); - at::Tensor run(LoadStoreOp* ld_st_op, at::Tensor input); -}; -} // namespace nvfuser