diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index 99b22087589..f9f9e848fff 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -151,6 +151,7 @@ bool isTvOp(const Expr* expr) { LoadStoreOp, MatmulOp, MmaOp, + LinearOp, BroadcastOp, SqueezeOp, ExpandOp, diff --git a/csrc/dispatch.h b/csrc/dispatch.h index c2150dda35e..714b37c3e15 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -108,6 +108,7 @@ class Val; f(Swizzle2D); \ f(Resize); \ f(MatmulOp); \ + f(LinearOp); \ f(Communication); #define DISPATCH_FOR_ALL_KIR_EXPRS(f) \ f(Allocate); \ diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index ba9608909db..0ddd0a704a1 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -2288,4 +2288,51 @@ class MatmulOp : public Expr { const std::vector& inputs) const override; }; +// Linear node with same functionality as F.linear +// (https://pytorch.org/docs/stable/generated/torch.nn.functional.linear.html#torch.nn.functional.linear) +class LinearOp : public Expr { + public: + using Expr::Expr; + + LinearOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b, Val* bias); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + const char* getOpString() const override { + return "LinearOp"; + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + + Val* out() const { + return output(0); + } + + Val* inA() const { + return input(0); + } + + Val* inB() const { + return input(1); + } + + Val* bias() const { + if (has_bias()) { + return input(2); + } else { + return nullptr; + } + } + + std::vector evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const override; + + private: + bool has_bias() const { + return inputs().size() == 3; + } +}; + } // namespace nvfuser diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 879504c07e2..e3369dad1b7 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -4501,4 +4501,51 @@ std::vector MatmulOp::evaluate( return {at::matmul(a, b)}; } +LinearOp::LinearOp( + IrBuilderPasskey passkey, + Val* out, + Val* in_a, + Val* in_b, + Val* bias) + : Expr(passkey) { + addOutput(out); + addInput(in_a); + addInput(in_b); + + if (bias != nullptr) { + addInput(bias); + } +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(LinearOp) + +std::string LinearOp::toString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << out()->toString() << "\n"; + indent(ss, indent_size + 1) << " = linear(" << inA()->toString() << ",\n"; + indent(ss, indent_size + 1) << " " << inB()->toString(); + if (has_bias()) { + indent(ss, indent_size + 1) << ",\n " << bias()->toString(); + } + indent(ss, indent_size + 1) << ")\n"; + return ss.str(); +} + +std::string LinearOp::toInlineString(int indent_size) const { + NVF_CHECK(false, "Tensor op can not be printed inline"); +} + +std::vector LinearOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto a = inputs.at(0).as(); + const auto b = inputs.at(1).as(); + + if (has_bias()) { + const auto bias = inputs.at(2).as(); + return {at::linear(a, b, bias)}; + } + return {at::linear(a, b)}; +} + } // namespace nvfuser diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp index cdeeaecb624..4685e0cf5ab 100644 --- a/csrc/ops/composite.cpp +++ b/csrc/ops/composite.cpp @@ -54,42 +54,92 @@ TensorView* dropout_backward(TensorView* dy, TensorView* mask, Val* scale) { return dx; } -TensorView* linear(TensorView* a, TensorView* b, TensorView* bias) { - // TODO: Support 1+ dimensional A. +namespace { + +static TensorView* newForLinear( + TensorView* input, + TensorView* weight, + TensorView* bias) { + auto input_domain = + TensorDomain::noReductions(input->getMaybeRFactorDomain()); + auto weight_domain = + TensorDomain::noReductions(weight->getMaybeRFactorDomain()); + + // Linear: a = {*, in_features}, b = {out_features, in_features} / + // {in_features}.The linear output is {*, (out_features), rK}. + // The first out_size -2 dimensions are as the first input, followed by + // out_features (if present) and an additional reduction axis K. + auto ndims_out = input_domain.size() + weight_domain.size() - 1; + + const std::vector& mapping_a = + ops::mapLinearOpIterDomains(input_domain, MatmulRole::INPUT_A, ndims_out); + const std::vector& mapping_b = ops::mapLinearOpIterDomains( + weight_domain, MatmulRole::INPUT_B, ndims_out); + std::vector mapping_bias(ndims_out, nullptr); + if (bias != nullptr) { + auto bias_domain = + TensorDomain::noReductions(bias->getMaybeRFactorDomain()); + mapping_bias = ops::mapLinearOpIterDomains( + bias_domain, MatmulRole::INPUT_C, ndims_out); + } + + std::vector out_domain(ndims_out, nullptr); + + for (auto idx : c10::irange(ndims_out - 1)) { + out_domain[idx] = ops::newOutputIterDomain( + {mapping_a.at(idx), mapping_b.at(idx), mapping_bias.at(idx)}); + } + // Specify the iterdomain for K as reduction + out_domain[ndims_out - 1] = ops::newOutputIterDomain( + {mapping_a.back(), mapping_b.back()}, + /*force_iter_type=*/IterType::Reduction); + + TensorDomain* td = IrBuilder::create( + out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)); + + return IrBuilder::create(td, input->dtype()); +} + +} // namespace + +TensorView* linear(TensorView* input, TensorView* weight, TensorView* bias) { + auto input_ndims = + TensorDomain::noReductions(input->getMaybeRFactorDomain()).size(); + NVF_CHECK(input_ndims > 0, "Input A must be atleast 1D."); + + auto weight_ndims = + TensorDomain::noReductions(weight->getMaybeRFactorDomain()).size(); NVF_CHECK( - (a->nDims() == 2 && b->nDims() == 2), - "Only 2-D Inputs and Weights are currently supported in Linear!"); - - std::vector bcast_dims(a->nDims() + 1, false); - // A: [M, Bcast, K] - // B: [Bcast, N, K] - bcast_dims.at(bcast_dims.size() - 2) = true; - auto* tv0b = broadcast(a, bcast_dims); - bcast_dims.at(bcast_dims.size() - 2) = false; - bcast_dims.at(bcast_dims.size() - 3) = true; - auto* tv1b = broadcast(b, bcast_dims); + weight_ndims == 1 || weight_ndims == 2, + "Input B must be a 1D / 2D tensor."); + // Note: This constraint is not documented but F.linear errors out if bias is + // given with 1D weights. NVF_CHECK( - a->getDataType().value() == b->getDataType().value(), - "data types of inputs to matmul don't match"); - - auto* output = fusedMultiplySum(tv0b, tv1b, {-1}); - if (bias) { - NVF_CHECK( - (bias->nDims() <= a->nDims()), "bias should be broadcastable to A"); - NVF_CHECK( - a->getDataType().value() == bias->getDataType().value(), - "bias doesn't match input/weight dtype"); - auto* bias_with_cast = maybeCastOp(output->getDataType().value(), bias); - auto* bcast_bias = ops::maybeBroadcast({output, bias_with_cast})[1]; - auto* bias_output = add(output, bcast_bias); - return maybeCastOp(a->getDataType().value(), bias_output); - } - return maybeCastOp(a->getDataType().value(), output); + weight_ndims == 2 || bias == nullptr, + "Expected B to be a 2D matrix if bias is given, got 1D.") + + NVF_CHECK( + input->dtype() == weight->dtype(), + "Expected input and weight dtypes to have the same dtype, got: ", + input->dtype(), + " and ", + weight->dtype()); + + NVF_CHECK( + bias == nullptr || bias->dtype() == input->dtype(), + "Expected bias to have the same dtype as A and B, got: ", + bias->dtype(), + " and ", + input->dtype()); + // For all other cases, create a new LinearOp + TensorView* out = newForLinear(input, weight, bias); + IrBuilder::create(out, input, weight, bias); + return out; } -TensorView* linear(TensorView* a, TensorView* b) { - return linear(a, b, nullptr /*bias*/); +TensorView* linear(TensorView* tv_a, TensorView* tv_b) { + return linear(tv_a, tv_b, /*bias=*/nullptr); } LstmResult lstm( @@ -293,15 +343,8 @@ static TensorView* newForMatmul(TensorView* tv_a, TensorView* tv_b) { orig_domain_b, MatmulRole::INPUT_B, ndims_out); for (auto idx : c10::irange(ndims_out - 1)) { - std::vector input_ids; - input_ids.reserve(2); - if (mapping_a[idx] != nullptr) { - input_ids.emplace_back(mapping_a[idx]); - } - if (mapping_b[idx] != nullptr) { - input_ids.emplace_back(mapping_b[idx]); - } - out_domain[idx] = ops::newOutputIterDomain(input_ids); + out_domain[idx] = + ops::newOutputIterDomain({mapping_a.at(idx), mapping_b.at(idx)}); } out_domain[ndims_out - 1] = ops::newOutputIterDomain( diff --git a/csrc/ops/composite.h b/csrc/ops/composite.h index fa617e75154..0ef555ebc59 100644 --- a/csrc/ops/composite.h +++ b/csrc/ops/composite.h @@ -47,17 +47,17 @@ NVF_API LstmResult lstm( TensorView* cell_x, TensorView* out_x); -// Linear functions which takes in two tensors of shapes A[M,K] and -// B[N,K]. Takes in a options bias of shape [N] and performs -// out = A * B_Transpose + bias. The output dtype matches the dtype -// ofthe inputs which should match. -TensorView* linear(TensorView* a, TensorView* b, TensorView* bias); +// Linear functions which takes in two tensors of shapes input[* , in_features], +// weight[out_features, in_features] / [in_features] and an optional bias of +// shape [out_features] or 0D scalar. Bias can only be given if weight is a 2-D +// tensor. +TensorView* linear(TensorView* input, TensorView* weight, TensorView* bias); // This is an implementation detail to reflect when linear is called // without a bias. This calls the above function. We use this function // since it simplifies creating a Python API which takes optional arguments. // Other options include using lambdas or creating a new RecordFunctor for // Linear. -TensorView* linear(TensorView* a, TensorView* b); +TensorView* linear(TensorView* input, TensorView* weight); NVF_API TensorView* sign(TensorView* x); NVF_API Val* sign(Val* x); diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index 16928f25002..43570a50ede 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -221,6 +221,45 @@ std::vector mapMatmulOpIterDomains( return mapping; } +std::vector mapLinearOpIterDomains( + const std::vector& input_domain, + MatmulRole input_role, + size_t out_size) { + std::vector mapping(out_size, nullptr); + auto inp_size = input_domain.size(); + + // Input A: {*, M, K} + // Input B: {*, N, K} / {K} + // Bias: {N} / {} + switch (input_role) { + case MatmulRole::INPUT_A: { + // Linear output is same as input for all but the last dimension + for (auto inx : c10::irange(inp_size - 1)) { + mapping[inx] = input_domain[inx]; + } + mapping[out_size - 1] = input_domain.back(); + break; + } + case MatmulRole::INPUT_B: { + for (auto inx : c10::irange(inp_size)) { + // Map N, K to the last two positions of the output. + mapping[out_size - 1 - inx] = input_domain[inp_size - 1 - inx]; + } + break; + } + case MatmulRole::INPUT_C: { + if (inp_size > 0) { + // Bias is 1D tensor of shape {out_features} + mapping[out_size - 2] = input_domain[0]; + } + break; + } + default: + NVF_ERROR("Unexpected input type."); + } + return mapping; +} + // Adding these pragmas since gcc-12.2.1 // incorrectly reports a warning with the use of evaluate #if defined(__GNUC__) && !defined(__clang__) @@ -228,7 +267,7 @@ std::vector mapMatmulOpIterDomains( #pragma GCC diagnostic ignored "-Wfree-nonheap-object" #endif IterDomain* newOutputIterDomain( - const std::vector& ids, + const std::vector& input_ids, const std::optional force_iter_type) { // For the start and stop offsets, take the maximum of input axes. // For now, the offsets of both start and stop are always integer @@ -242,6 +281,16 @@ IterDomain* newOutputIterDomain( Val* expanded_extent_val = nullptr; std::optional iter_type = std::nullopt; + std::vector ids; + ids.reserve(input_ids.size()); + + // Filter out any nullptrs + std::copy_if( + input_ids.begin(), + input_ids.end(), + std::back_inserter(ids), + [](IterDomain* id) { return id != nullptr; }); + for (auto id : ids) { if (id->isBroadcast()) { if (id->hasExpandedExtent()) { diff --git a/csrc/ops/utils.h b/csrc/ops/utils.h index 47ac94b7d04..5c0982bb39e 100644 --- a/csrc/ops/utils.h +++ b/csrc/ops/utils.h @@ -46,11 +46,32 @@ IterType promoteIterType(IterType type1, IterType type2); // Mapping B: {nullptr, id_N}) // 3. A/B are atleast 1D and one of them is > 2D: [B, M, K] x [K, N] -> [B, M, // N] (Mapping A: {id_B, id_M, nullptr}, Mapping B: {nullptr, nullptr, id_N}) +// Args: +// 1. input_domain: root/rfactor domain without reductions for any input to +// MatmulOp +// 2. input_role: Specifies if the input is A / B (MatmulRole::Input_A/Input_B) +// 3: out_size: MatmulOp output dimension (input and output may not be the same +// size). std::vector mapMatmulOpIterDomains( const std::vector& input_domain, MatmulRole input_role, size_t out_size); +// For LinearOp, the output is the same as the first input (A[*, +// in_features])for all but the last dimension. If the second input is 2D +// (B[out_features, in_features]), the last dimension of output is out_features. +// If bias is 1D (bias[out_features]) it maps to the last dimension of the +// output. Args: +// 1. input_domain: root/rfactor domain without reductions for any input to +// LinearOp +// 2. input_role: Specifies if the input is A / B / Bias +// (MatmulRole::Input_A/Input_B/Input_C) 3: out_size: LinearOp output dimension +// (input and output may not be the same size). +std::vector mapLinearOpIterDomains( + const std::vector& input_domain, + MatmulRole input_role, + size_t out_size); + // Takes a vector of aligned input iterdomains to create the output iterdomain. // This is used if the input iterdomains are not trivially mapped to the output // iterdomains. For eg: MatmulOp. If given, the forced_iter_type argument will diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index 0f6e01ded31..01b2adbd4c6 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/root_domain_map.cpp @@ -165,6 +165,21 @@ std::unordered_map PairwiseRootDomainMap::map( } }; + // Assumes producer and consumer IDs to be trivially aligned and adds them to + // domain map. + auto pairwiseMapAllIds = [&](std::vector producer_ids, + std::vector consumer_ids) { + NVF_ERROR(producer_ids.size() == consumer_ids.size()); + for (auto idx : c10::irange(consumer_ids.size())) { + IterDomain* producer_id = producer_ids.at(idx); + IterDomain* consumer_id = consumer_ids.at(idx); + if (producer_id == nullptr) { + continue; + } + updatePairwiseRootDomainMap(producer_id, consumer_id); + } + }; + // For MatmulOp, use the corresponding mapped input iterdomains. if (MatmulOp* op = dynamic_cast(consumer_tv_->definition())) { // Check if the producer is lhs/rhs input @@ -183,18 +198,35 @@ std::unordered_map PairwiseRootDomainMap::map( // maps to the third output iterdomain. const std::vector& aligned_producer_ids = ops::mapMatmulOpIterDomains(producer_root, input_role, out_size); + pairwiseMapAllIds(aligned_producer_ids, consumer_root); + return dom_map; + } - NVF_ERROR(aligned_producer_ids.size() == consumer_root.size()); + if (LinearOp* op = dynamic_cast(consumer_tv_->definition())) { + auto out_size = consumer_root.size(); - for (auto inx : c10::irange(out_size)) { - IterDomain* producer_id = aligned_producer_ids.at(inx); - IterDomain* consumer_id = consumer_root.at(inx); - if (producer_id == nullptr) { - continue; - } - updatePairwiseRootDomainMap(producer_id, consumer_id); + // Check if the producer is A, B or bias. + std::optional input_role = std::nullopt; + if (producer->sameAs(op->inA()->as()->domain())) { + input_role = MatmulRole::INPUT_A; + } else if (producer->sameAs(op->inB()->as()->domain())) { + input_role = MatmulRole::INPUT_B; + } else if (producer->sameAs(op->bias()->as()->domain())) { + input_role = MatmulRole::INPUT_C; + } else { + NVF_ERROR(false, "Producer did not match any LinearOp input.") } + // LinearOp: + // inputs (INPUT_A) = {*, in_features} + // weight (INPUT_B) = {out_features, in_features} / {in_features} + // bias (INPUT_C) = {out_features} / {} + // output = {*, out_features} / {*} + + const std::vector& aligned_producer_ids = + ops::mapLinearOpIterDomains( + producer_root, input_role.value(), out_size); + pairwiseMapAllIds(aligned_producer_ids, consumer_root); return dom_map; } diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp index c600b2f0bea..b25a290ce99 100644 --- a/csrc/scheduler/expr_eval_sched.cpp +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -13,15 +13,16 @@ namespace nvfuser { -// Check if the fusion has a single MatmulOp node +// Check if the fusion has a single MatmulOp/LinearOp node bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { auto exprs = fusion->exprs(); - if (exprs.size() == 1 && exprs.front()->isA()) { + if (exprs.size() == 1 && + (exprs.front()->isA() || exprs.front()->isA())) { return true; } scheduler_debug_utils::canScheduleRejectReason( heuristicType(), - "Fusion must contain a single expression of type MatmulOp"); + "Fusion must contain a single expression of type MatmulOp or LinearOp"); return false; } diff --git a/tests/cpp/test_matmul_aten_evaluation.cpp b/tests/cpp/test_matmul_aten_evaluation.cpp index 94da409b504..80287793480 100644 --- a/tests/cpp/test_matmul_aten_evaluation.cpp +++ b/tests/cpp/test_matmul_aten_evaluation.cpp @@ -35,12 +35,25 @@ class MatmulATenEvaluationTest : public NVFuserTest { using Sizes = std::vector; using MatmulNodeParamType = std::tuple; -class ATenNodesParametrizedTest +class MatmulNodeParametrizedTest : public NVFuserFixtureParamTest { protected: // Allocation order set by the pass breaks matmul tests // see issue https://github.com/NVIDIA/Fuser/issues/1810 - ATenNodesParametrizedTest() : optimization_guard_(false) {} + MatmulNodeParametrizedTest() : optimization_guard_(false) {} + + private: + preseg_passes::OptimizationPassGuard + optimization_guard_; +}; + +using LinearNodeParamType = std::tuple>; +class LinearNodeParametrizedTest + : public NVFuserFixtureParamTest { + protected: + // Allocation order set by the pass breaks matmul tests + // see issue https://github.com/NVIDIA/Fuser/issues/1810 + LinearNodeParametrizedTest() : optimization_guard_(false) {} private: preseg_passes::OptimizationPassGuard @@ -410,6 +423,15 @@ TEST_F(MatmulATenEvaluationTest, LinearWithBias) { EXPECT_TRUE(at::allclose(out[0], out_ref)); } +const bool checkMapped(const ValGraph& vg, IterDomain* x, IterDomain* y) { + if (!vg.hasGroup(x) || !vg.hasGroup(y)) { + return false; + } + const ValGroup& gx = vg.toGroup(x); + const ValGroup& gy = vg.toGroup(y); + return gx.get() == gy.get(); +}; + // Check that ID exact mapping works as expected void checkMatmulOpIdMapping( Fusion* fusion, @@ -420,15 +442,6 @@ void checkMatmulOpIdMapping( const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); vg.validateConsistency(); - const auto checkMapped = [&vg](IterDomain* x, IterDomain* y) -> bool { - if (!vg.hasGroup(x) || !vg.hasGroup(y)) { - return false; - } - const ValGroup& gx = vg.toGroup(x); - const ValGroup& gy = vg.toGroup(y); - return gx.get() == gy.get(); - }; - // If K is Broadcast then we will not have a reduction dim bool k_bcast = A->axis(-1)->isBroadcast(); int64_t red_dims = k_bcast ? 0 : 1; @@ -440,44 +453,44 @@ void checkMatmulOpIdMapping( EXPECT_EQ(output->nDims(), 0); // When K is Broadcast, we squeeze then multiply then cast instead if (!k_bcast) { - EXPECT_TRUE(checkMapped(A->axis(0), B->axis(0))); // K + EXPECT_TRUE(checkMapped(vg, A->axis(0), B->axis(0))); // K } } else if (A->nDims() > 1 && B->nDims() == 1) { // [..., iM, iK] @ [iK] = [..., iM, rK] ASSERT_EQ(output->nDims(), A->nDims() + red_dims - 1); - EXPECT_TRUE(checkMapped(A->axis(-2), output->axis(-1 - red_dims))); // M + EXPECT_TRUE(checkMapped(vg, A->axis(-2), output->axis(-1 - red_dims))); // M if (!k_bcast) { - EXPECT_TRUE(checkMapped(A->axis(-1), B->axis(0))); // K - EXPECT_TRUE(checkMapped(A->axis(-1), output->axis(-1))); // K + EXPECT_TRUE(checkMapped(vg, A->axis(-1), B->axis(0))); // K + EXPECT_TRUE(checkMapped(vg, A->axis(-1), output->axis(-1))); // K } // Check that batch dims are mapped for (int64_t i : c10::irange(output->nDims() - red_dims - 1)) { if (!A->axis(i)->isBroadcast()) { - EXPECT_TRUE(checkMapped(A->axis(i), output->axis(i))); + EXPECT_TRUE(checkMapped(vg, A->axis(i), output->axis(i))); } } } else if (A->nDims() == 1 && B->nDims() > 1) { // [iK] @ [..., iK, iN] = [..., iN, rK] ASSERT_EQ(output->nDims(), B->nDims() + red_dims - 1); - EXPECT_TRUE(checkMapped(B->axis(-1), output->axis(-1 - red_dims))); // N + EXPECT_TRUE(checkMapped(vg, B->axis(-1), output->axis(-1 - red_dims))); // N if (!k_bcast) { - EXPECT_TRUE(checkMapped(A->axis(0), B->axis(-2))); // K - EXPECT_TRUE(checkMapped(A->axis(0), output->axis(-1))); // K + EXPECT_TRUE(checkMapped(vg, A->axis(0), B->axis(-2))); // K + EXPECT_TRUE(checkMapped(vg, A->axis(0), output->axis(-1))); // K } // Check that batch dims are mapped for (int64_t i : c10::irange(output->nDims() - red_dims - 1)) { if (!B->axis(i)->isBroadcast()) { - EXPECT_TRUE(checkMapped(B->axis(i), output->axis(i))); + EXPECT_TRUE(checkMapped(vg, B->axis(i), output->axis(i))); } } } else if (A->nDims() > 1 && B->nDims() > 1) { // [..., iM, iK] @ [..., iK, iN] = [..., iM, iN, rK] ASSERT_EQ(output->nDims(), std::max(A->nDims(), B->nDims()) + red_dims); - EXPECT_TRUE(checkMapped(A->axis(-2), output->axis(-2 - red_dims))); // M - EXPECT_TRUE(checkMapped(B->axis(-1), output->axis(-1 - red_dims))); // N + EXPECT_TRUE(checkMapped(vg, A->axis(-2), output->axis(-2 - red_dims))); // M + EXPECT_TRUE(checkMapped(vg, B->axis(-1), output->axis(-1 - red_dims))); // N if (!k_bcast) { - EXPECT_TRUE(checkMapped(A->axis(-1), B->axis(-2))); // K - EXPECT_TRUE(checkMapped(A->axis(-1), output->axis(-1))); // K + EXPECT_TRUE(checkMapped(vg, A->axis(-1), B->axis(-2))); // K + EXPECT_TRUE(checkMapped(vg, A->axis(-1), output->axis(-1))); // K } // Check that batch dims are mapped // Note that A and B can have different dimensions, so here we count @@ -488,10 +501,10 @@ void checkMatmulOpIdMapping( int64_t i_b = B->nDims() - 3 - i; int64_t i_out = output->nDims() - red_dims - 3 - i; if (i_a >= 0 && !A->axis(i_a)->isBroadcast()) { - EXPECT_TRUE(checkMapped(A->axis(i_a), output->axis(i_out))); + EXPECT_TRUE(checkMapped(vg, A->axis(i_a), output->axis(i_out))); } if (i_b >= 0 && !B->axis(i_b)->isBroadcast()) { - EXPECT_TRUE(checkMapped(B->axis(i_b), output->axis(i_out))); + EXPECT_TRUE(checkMapped(vg, B->axis(i_b), output->axis(i_out))); } } } else { @@ -500,7 +513,47 @@ void checkMatmulOpIdMapping( } } -TEST_P(ATenNodesParametrizedTest, MatmulNodeConcrete) { +// Check that ID exact mapping works as expected +void checkLinearOpIdMapping( + Fusion* fusion, + TensorView* input, + TensorView* weight, + TensorView* bias, + TensorView* output) { + IdModel id_model(fusion); + const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); + vg.validateConsistency(); + + // input: [* , in_features] + // weight: [out_features, in_features] / [out_features] + // bias (optional): [out_features]/[] + // output = [*, (out_features), rK] + + ASSERT_EQ(output->nDims(), input->nDims() + weight->nDims() - 1); + + // Check that the first input_size - 1 dims are mapped for input + for (auto i : c10::irange(input->nDims() - 1)) { + if (!input->axis(i)->isBroadcast()) { + EXPECT_TRUE(checkMapped(vg, input->axis(i), output->axis(i))); + } + } + // Check out_features dim is mapped in weight & bias if present. + if (weight->nDims() > 1) { + if (!weight->axis(0)->isBroadcast()) { + EXPECT_TRUE(checkMapped(vg, weight->axis(0), output->axis(-2))); + } + if (bias != nullptr && bias->nDims() > 0 && !bias->axis(0)->isBroadcast()) { + EXPECT_TRUE(checkMapped(vg, bias->axis(0), output->axis(-2))); + } + } + // Check mapping for reduction axis in input and weight + if (!input->axis(-1)->isBroadcast()) { + EXPECT_TRUE(checkMapped(vg, input->axis(-1), weight->axis(-1))); + EXPECT_TRUE(checkMapped(vg, input->axis(-1), output->axis(-1))); + } +} + +TEST_P(MatmulNodeParametrizedTest, MatmulNodeConcrete) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -526,7 +579,7 @@ TEST_P(ATenNodesParametrizedTest, MatmulNodeConcrete) { EXPECT_TRUE(at::allclose(out[0], out_ref)); } -TEST_P(ATenNodesParametrizedTest, MatmulNodeSymbolic) { +TEST_P(MatmulNodeParametrizedTest, MatmulNodeSymbolic) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -552,12 +605,111 @@ TEST_P(ATenNodesParametrizedTest, MatmulNodeSymbolic) { EXPECT_TRUE(at::allclose(out[0], out_ref)); } +TEST_P(LinearNodeParametrizedTest, LinearNodeConcrete) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const auto& [a_shape, b_shape, bias_shape] = GetParam(); + + auto tv0 = makeConcreteTensor(a_shape, DataType::Half); + auto tv1 = makeConcreteTensor(b_shape, DataType::Half); + TensorView* bias = nullptr; + if (bias_shape.has_value()) { + bias = makeConcreteTensor(*bias_shape, DataType::Half); + } + auto tv2 = linear(tv0, tv1, bias); + + fusion->addInput(tv0); + fusion->addInput(tv1); + if (bias_shape.has_value()) { + fusion->addInput(bias); + } + fusion->addOutput(tv2); + + checkLinearOpIdMapping(fusion.get(), tv0, tv1, bias, tv2); + + at::Tensor t0 = at::randn(a_shape, at::kHalf).cuda(); + at::Tensor t1 = at::randn(b_shape, at::kHalf).cuda(); + std::optional bias_opt = std::nullopt; + if (bias_shape.has_value()) { + bias_opt = at::randn(*bias_shape, at::kHalf).cuda(); + } + at::Tensor out_ref = at::linear(t0, t1, bias_opt); + + FusionExecutorCache fec(std::move(fusion)); + + std::vector out = {}; + if (bias_shape.has_value()) { + out = fec.runFusionWithInputs({t0, t1, bias_opt}); + } else { + out = fec.runFusionWithInputs({t0, t1}); + } + + const std::vector& executors = + fec.getMostRecentKernelRuntime()->executors(); + EXPECT_EQ(executors.size(), 1); + // Verify that fusion compilation was skipped. + EXPECT_FALSE(executors.front().hasCompiledKernel()); + + EXPECT_TRUE(at::allclose(out[0], out_ref)); +} +TEST_P(LinearNodeParametrizedTest, LinearNodeSymbolic) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const auto& [a_shape, b_shape, bias_shape] = GetParam(); + + auto tv0 = makeSymbolicTensor(a_shape, DataType::Half); + auto tv1 = makeSymbolicTensor(b_shape, DataType::Half); + + TensorView* bias = nullptr; + if (bias_shape.has_value()) { + bias = makeSymbolicTensor(*bias_shape, DataType::Half); + } + + auto tv2 = linear(tv0, tv1, bias); + + fusion->addInput(tv0); + fusion->addInput(tv1); + if (bias_shape.has_value()) { + fusion->addInput(bias); + } + fusion->addOutput(tv2); + + checkLinearOpIdMapping(fusion.get(), tv0, tv1, bias, tv2); + + at::Tensor t0 = at::randn(a_shape, at::kHalf).cuda(); + at::Tensor t1 = at::randn(b_shape, at::kHalf).cuda(); + std::optional bias_opt = std::nullopt; + if (bias_shape.has_value()) { + bias_opt = at::randn(*bias_shape, at::kHalf).cuda(); + } + at::Tensor out_ref = at::linear(t0, t1, bias_opt); + + FusionExecutorCache fec(std::move(fusion)); + + std::vector out = {}; + if (bias_shape.has_value()) { + out = fec.runFusionWithInputs({t0, t1, bias_opt}); + } else { + out = fec.runFusionWithInputs({t0, t1}); + } + + const std::vector& executors = + fec.getMostRecentKernelRuntime()->executors(); + EXPECT_EQ(executors.size(), 1); + // Verify that fusion compilation was skipped. + EXPECT_FALSE(executors.front().hasCompiledKernel()); + + EXPECT_TRUE(at::allclose(out[0], out_ref)); +} + constexpr int64_t b = 128, m = 64, k = 32, n = 16; // Parametrize a_shape and b_shape INSTANTIATE_TEST_SUITE_P( , - ATenNodesParametrizedTest, + MatmulNodeParametrizedTest, testing::Combine( testing::Values( Sizes({k}), @@ -574,7 +726,7 @@ INSTANTIATE_TEST_SUITE_P( // Test case where K=1 INSTANTIATE_TEST_SUITE_P( ReductionAxisIsOne, - ATenNodesParametrizedTest, + MatmulNodeParametrizedTest, testing::Combine( testing::Values( Sizes({1}), @@ -588,4 +740,43 @@ INSTANTIATE_TEST_SUITE_P( Sizes({1, 1}), Sizes({b, 1, n})))); +INSTANTIATE_TEST_SUITE_P( + LinearWithoutBias, + LinearNodeParametrizedTest, + testing::Combine( + testing::Values( + Sizes({k}), + Sizes({m, k}), + Sizes({b, m, k}), + Sizes({1, k}), + Sizes({b, 1, k})), + testing::Values(Sizes({k}), Sizes({n, k}), Sizes({1, k})), + testing::Values(std::nullopt))); + +INSTANTIATE_TEST_SUITE_P( + LinearWithBias, + LinearNodeParametrizedTest, + testing::Combine( + testing::Values( + Sizes({k}), + Sizes({m, k}), + Sizes({b, m, k}), + Sizes({1, k}), + Sizes({b, 1, k})), + testing::Values(Sizes({n, k})), + testing::Values(Sizes({}), Sizes({n})))); + +INSTANTIATE_TEST_SUITE_P( + LinearReductionAxisIsOne, + LinearNodeParametrizedTest, + testing::Combine( + testing::Values( + Sizes({1}), + Sizes({m, 1}), + Sizes({b, m, 1}), + Sizes({1, 1}), + Sizes({b, 1, 1})), + testing::Values(Sizes({n, 1})), + testing::Values(Sizes({}), Sizes({n})))); + } // namespace nvfuser diff --git a/tests/python/pytest_fusion_definitions.py b/tests/python/pytest_fusion_definitions.py index d25448510c9..5e4d164584c 100644 --- a/tests/python/pytest_fusion_definitions.py +++ b/tests/python/pytest_fusion_definitions.py @@ -21,14 +21,18 @@ def parse_inputs_fusion_definition(fd: FusionDefinition, opinfo: OpInfo, *args): nvf_args = [] - if opinfo.symbolic_parameter_list is None: - opinfo.symbolic_parameter_list = [ArgumentType.Symbolic] * len(args) - num_symbolic_parameters = len(opinfo.symbolic_parameter_list) + symbolic_parameter_list = ( + opinfo.symbolic_parameter_list + if opinfo.symbolic_parameter_list is not None + else [ArgumentType.Symbolic] * len(args) + ) + + num_symbolic_parameters = len(symbolic_parameter_list) assert num_symbolic_parameters == len( args ), f"{num_symbolic_parameters} vs {len(args)}" - for arg_type, a in zip(opinfo.symbolic_parameter_list, args): + for arg_type, a in zip(symbolic_parameter_list, args): if arg_type == ArgumentType.Symbolic: if isinstance(a, torch.Tensor): nvf_args.append(fd.from_pytorch(a)) diff --git a/tests/python/pytest_input_generators.py b/tests/python/pytest_input_generators.py index 8237381d6ed..137dab7c229 100644 --- a/tests/python/pytest_input_generators.py +++ b/tests/python/pytest_input_generators.py @@ -1524,18 +1524,53 @@ def linear_input_generator( requires_grad=requires_grad, ) - def multiply_range(maximum, step): - assert maximum % step == 0 - num_steps = int(math.log(maximum, step)) - return tuple( - map(pow, itertools.repeat(step, num_steps), range(1, num_steps + 1)) + B = 64 + M = 512 + N = 256 + K = 32 + + # Cases without bias + shapes_input = ((K), (M, K), (B, M, K), (B, 1, M, K)) + shapes_weight = ((K), (N, K), (1, K)) + for shape_input, shape_weight in itertools.product(shapes_input, shapes_weight): + yield SampleInput(make_arg(shape_input), make_arg(shape_weight)) + + # Cases with bias + shape_weight = (N, K) + shapes_bias = ((), (N,)) + for shape_input, shape_bias in itertools.product(shapes_input, shapes_bias): + yield SampleInput( + make_arg(shape_input), make_arg(shape_weight), make_arg(shape_bias) ) - # Ranges of tensor sizes: 8, 64, 512, 4096, 32768, ... - # Use a Cartesian product to create a wide range of matrix shapes - # I'll stop at 512 as possible numerical difference may show up. - M, N, K = itertools.repeat(multiply_range(512, 8), 3) - for M, N, K in itertools.product(M, N, K): - lhs_shape = (M, K) - rhs_shape = (N, K) - yield (SampleInput(make_arg(lhs_shape), make_arg(rhs_shape), make_arg((N,)))) + +def linear_error_generator( + op, dtype=torch.float32, requires_grad: bool = False, **kwargs +): + make_arg = partial( + make_tensor, device="cuda", dtype=dtype, requires_grad=requires_grad + ) + # shapes, dim, exception type, exception string + M = 512 + N = 256 + K = 32 + + bias_with_1dweight = ( + ((M, K), (K), (N)), + RuntimeError, + "Expected B to be a 2D matrix if bias is given, got 1D.", + ) + + mismatched_bias_extent = ( + ((M, K), (1, K), (N)), + RuntimeError, + f"The expanded size of the tensor (1) must match the existing size ({N}) at non-singleton dimension 1. Target sizes: [{M}, 1]. Tensor sizes: [{N}]", + ) + + error_cases = [bias_with_1dweight, mismatched_bias_extent] + + for input_shapes, ex_type, ex_str in error_cases: + shape_input, shape_weight, shape_bias = input_shapes + yield SampleInput( + make_arg(shape_input), make_arg(shape_weight), make_arg(shape_bias) + ), ex_type, ex_str diff --git a/tests/python/pytest_opinfos.py b/tests/python/pytest_opinfos.py index 5d69f57891a..480810d6927 100644 --- a/tests/python/pytest_opinfos.py +++ b/tests/python/pytest_opinfos.py @@ -50,6 +50,7 @@ where_error_generator, matmul_input_generator, linear_input_generator, + linear_error_generator, ) from pytest_utils import ( bool_int_dtypes, @@ -1133,6 +1134,7 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor): else (torch.float16,) ), sample_input_generator=linear_input_generator, + error_input_generator=linear_error_generator, reference=torch.nn.functional.linear, ) linear_ops.append(linear_opinfo) diff --git a/tests/python/pytest_ops.py b/tests/python/pytest_ops.py index 713f7dc35ef..690e0294eca 100644 --- a/tests/python/pytest_ops.py +++ b/tests/python/pytest_ops.py @@ -26,12 +26,16 @@ def parse_args_fusion_execution(opinfo: OpInfo, *args): if len(args) == 0: return [] - if opinfo.symbolic_parameter_list is None: - opinfo.symbolic_parameter_list = [ArgumentType.Symbolic] * len(args) - assert len(opinfo.symbolic_parameter_list) == len(args) + symbolic_parameter_list = ( + opinfo.symbolic_parameter_list + if opinfo.symbolic_parameter_list is not None + else [ArgumentType.Symbolic] * len(args) + ) + + assert len(symbolic_parameter_list) == len(args) result = [] - for arg_type, a in zip(opinfo.symbolic_parameter_list, args): + for arg_type, a in zip(symbolic_parameter_list, args): if arg_type == ArgumentType.Symbolic: if isinstance(a, list) and all(map(is_tensor, a)): result.extend(a) @@ -205,11 +209,11 @@ def errors_test_fn( fd.execute(parse_args_fusion_execution(nvf_op, *sample.args)) -# A pair of parentheses () represents a capture group in regex. +# A pair of parentheses ()/[] represents a capture group in regex. # Escape parenthesis in regex string to match raw characters. def _regex_escape_parenthesis(a: str) -> str: - b = a.replace(r"(", r"\(") - return b.replace(r")", r"\)") + b = a.replace(r"[", r"\[").replace(r"]", r"\]") + return b.replace(r"(", r"\(").replace(r")", r"\)") @create_op_test(tuple(op for op in opinfos if op.error_input_generator is not None)) diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index 41664ccb0c1..00b78c89757 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -2408,7 +2408,6 @@ def test_linear(self): k = 8 bias0d = torch.tensor(3.14, device="cuda", dtype=torch.float16) bias1d = torch.randn(n, device="cuda", dtype=torch.float16) - bias2d = torch.rand(m, n, device="cuda", dtype=torch.float16) inputs_mk_nk = [ torch.randn(m, k, device="cuda", dtype=torch.float16), @@ -2446,7 +2445,7 @@ def fusion_func( fd.add_output(t_out) in_tensors = [inputs_mk_nk, inputs_mk_kn, inputs_km_nk, inputs_km_kn] - use_bias = [None, bias0d, bias1d, bias2d] + use_bias = [None, bias0d, bias1d] for [inp, wt], use_bias in list(itertools.product(in_tensors, use_bias)): with self.subTest(inp=inp, wt=wt, use_bias=use_bias): input_tensors = ( diff --git a/version.txt b/version.txt index abd410582de..3a4036fb450 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.2.4 +0.2.5