Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions csrc/device_lower/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ bool isTvOp(const Expr* expr) {
LoadStoreOp,
MatmulOp,
MmaOp,
LinearOp,
BroadcastOp,
SqueezeOp,
ExpandOp,
Expand Down
1 change: 1 addition & 0 deletions csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class Val;
f(Swizzle2D); \
f(Resize); \
f(MatmulOp); \
f(LinearOp); \
f(Communication);
#define DISPATCH_FOR_ALL_KIR_EXPRS(f) \
f(Allocate); \
Expand Down
47 changes: 47 additions & 0 deletions csrc/ir/internal_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -2288,4 +2288,51 @@ class MatmulOp : public Expr {
const std::vector<PolymorphicValue>& inputs) const override;
};

// Linear node providing the same functionality as F.linear
// (https://pytorch.org/docs/stable/generated/torch.nn.functional.linear.html#torch.nn.functional.linear)
class LinearOp : public Expr {
 public:
  using Expr::Expr;

  // bias may be nullptr, in which case only two inputs are registered.
  LinearOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b, Val* bias);

  NVFUSER_DECLARE_CLONE_AND_CREATE

  const char* getOpString() const override {
    return "LinearOp";
  }

  std::string toString(int indent_size = 0) const override;
  std::string toInlineString(int indent_size = 0) const override;

  // The single output of the op.
  Val* out() const {
    return output(0);
  }

  // First operand (the "input" of F.linear).
  Val* inA() const {
    return input(0);
  }

  // Second operand (the "weight" of F.linear).
  Val* inB() const {
    return input(1);
  }

  // Optional bias operand; nullptr when the op was built without one.
  Val* bias() const {
    return has_bias() ? input(2) : nullptr;
  }

  std::vector<PolymorphicValue> evaluate(
      const ExpressionEvaluator& ee,
      const std::vector<PolymorphicValue>& inputs) const override;

 private:
  // A bias was provided iff a third input was registered (see constructor).
  bool has_bias() const {
    return inputs().size() == 3;
  }
};

} // namespace nvfuser
47 changes: 47 additions & 0 deletions csrc/ir/nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4501,4 +4501,51 @@ std::vector<PolymorphicValue> MatmulOp::evaluate(
return {at::matmul(a, b)};
}

// Registers the output and operands of a linear op. The bias is optional;
// when absent the op carries only two inputs, which is how has_bias()
// distinguishes the two forms.
LinearOp::LinearOp(
    IrBuilderPasskey passkey,
    Val* out,
    Val* in_a,
    Val* in_b,
    Val* bias)
    : Expr(passkey) {
  addOutput(out);
  addInput(in_a);
  addInput(in_b);
  // Only register the bias when one was supplied.
  if (bias) {
    addInput(bias);
  }
}

NVFUSER_DEFINE_CLONE_AND_CREATE(LinearOp)

// Multi-line pretty-printer. Renders the op roughly as:
//   <out>
//    = linear(<inA>,
//             <inB>[, <bias>])
// The bias operand is printed only when present (3-input form).
std::string LinearOp::toString(int indent_size) const {
  std::stringstream ss;
  indent(ss, indent_size) << out()->toString() << "\n";
  indent(ss, indent_size + 1) << " = linear(" << inA()->toString() << ",\n";
  indent(ss, indent_size + 1) << " " << inB()->toString();
  if (has_bias()) {
    indent(ss, indent_size + 1) << ",\n " << bias()->toString();
  }
  indent(ss, indent_size + 1) << ")\n";
  return ss.str();
}

// Tensor ops are not printable inline; this unconditionally fails via
// NVF_CHECK(false, ...), so no value is ever returned.
std::string LinearOp::toInlineString(int indent_size) const {
  NVF_CHECK(false, "Tensor op can not be printed inline");
}

// Concrete evaluation via ATen: delegates to at::linear, forwarding the
// bias tensor (inputs[2]) only when this op was constructed with one.
std::vector<PolymorphicValue> LinearOp::evaluate(
    const ExpressionEvaluator& ee,
    const std::vector<PolymorphicValue>& inputs) const {
  const auto in_a = inputs.at(0).as<at::Tensor>();
  const auto in_b = inputs.at(1).as<at::Tensor>();

  if (!has_bias()) {
    return {at::linear(in_a, in_b)};
  }
  return {at::linear(in_a, in_b, inputs.at(2).as<at::Tensor>())};
}

} // namespace nvfuser
123 changes: 83 additions & 40 deletions csrc/ops/composite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,42 +54,92 @@ TensorView* dropout_backward(TensorView* dy, TensorView* mask, Val* scale) {
return dx;
}

TensorView* linear(TensorView* a, TensorView* b, TensorView* bias) {
// TODO: Support 1+ dimensional A.
namespace {

// Builds the output TensorView of a LinearOp by aligning the iterdomains of
// input / weight / bias to the output via ops::mapLinearOpIterDomains and
// promoting each aligned group with ops::newOutputIterDomain.
static TensorView* newForLinear(
    TensorView* input,
    TensorView* weight,
    TensorView* bias) {
  // Work with the no-reduction rfactor domains of each operand; reduction
  // axes of producers do not propagate to the output.
  auto input_domain =
      TensorDomain::noReductions(input->getMaybeRFactorDomain());
  auto weight_domain =
      TensorDomain::noReductions(weight->getMaybeRFactorDomain());

  // Linear: input = {*, in_features}, weight = {out_features, in_features}
  // or {in_features}. The linear output is {*, (out_features), rK}: the
  // leading dimensions match the input, followed by out_features (only when
  // weight is 2D), and a trailing reduction axis K over in_features.
  auto ndims_out = input_domain.size() + weight_domain.size() - 1;

  const std::vector<IterDomain*>& mapping_a =
      ops::mapLinearOpIterDomains(input_domain, MatmulRole::INPUT_A, ndims_out);
  const std::vector<IterDomain*>& mapping_b = ops::mapLinearOpIterDomains(
      weight_domain, MatmulRole::INPUT_B, ndims_out);
  // Without a bias the mapping stays all-nullptr, so newOutputIterDomain
  // ignores the bias slot below.
  std::vector<IterDomain*> mapping_bias(ndims_out, nullptr);
  if (bias != nullptr) {
    auto bias_domain =
        TensorDomain::noReductions(bias->getMaybeRFactorDomain());
    mapping_bias = ops::mapLinearOpIterDomains(
        bias_domain, MatmulRole::INPUT_C, ndims_out);
  }

  std::vector<IterDomain*> out_domain(ndims_out, nullptr);

  // Promote the aligned operand iterdomains into each non-reduction output
  // axis (nullptr entries are filtered out by newOutputIterDomain).
  for (auto idx : c10::irange(ndims_out - 1)) {
    out_domain[idx] = ops::newOutputIterDomain(
        {mapping_a.at(idx), mapping_b.at(idx), mapping_bias.at(idx)});
  }
  // Specify the iterdomain for K as reduction
  out_domain[ndims_out - 1] = ops::newOutputIterDomain(
      {mapping_a.back(), mapping_b.back()},
      /*force_iter_type=*/IterType::Reduction);

  TensorDomain* td = IrBuilder::create<TensorDomain>(
      out_domain, TensorDomain::getContiguityFilledWith(out_domain, true));

  // Output dtype follows the input; the caller (linear) checks that input,
  // weight, and bias dtypes all match before reaching this point.
  return IrBuilder::create<TensorView>(td, input->dtype());
}

} // namespace

TensorView* linear(TensorView* input, TensorView* weight, TensorView* bias) {
auto input_ndims =
TensorDomain::noReductions(input->getMaybeRFactorDomain()).size();
NVF_CHECK(input_ndims > 0, "Input A must be atleast 1D.");

auto weight_ndims =
TensorDomain::noReductions(weight->getMaybeRFactorDomain()).size();
NVF_CHECK(
(a->nDims() == 2 && b->nDims() == 2),
"Only 2-D Inputs and Weights are currently supported in Linear!");

std::vector<bool> bcast_dims(a->nDims() + 1, false);
// A: [M, Bcast, K]
// B: [Bcast, N, K]
bcast_dims.at(bcast_dims.size() - 2) = true;
auto* tv0b = broadcast(a, bcast_dims);
bcast_dims.at(bcast_dims.size() - 2) = false;
bcast_dims.at(bcast_dims.size() - 3) = true;
auto* tv1b = broadcast(b, bcast_dims);
weight_ndims == 1 || weight_ndims == 2,
"Input B must be a 1D / 2D tensor.");

// Note: This constraint is not documented but F.linear errors out if bias is
// given with 1D weights.
NVF_CHECK(
a->getDataType().value() == b->getDataType().value(),
"data types of inputs to matmul don't match");

auto* output = fusedMultiplySum(tv0b, tv1b, {-1});
if (bias) {
NVF_CHECK(
(bias->nDims() <= a->nDims()), "bias should be broadcastable to A");
NVF_CHECK(
a->getDataType().value() == bias->getDataType().value(),
"bias doesn't match input/weight dtype");
auto* bias_with_cast = maybeCastOp(output->getDataType().value(), bias);
auto* bcast_bias = ops::maybeBroadcast({output, bias_with_cast})[1];
auto* bias_output = add(output, bcast_bias);
return maybeCastOp(a->getDataType().value(), bias_output);
}
return maybeCastOp(a->getDataType().value(), output);
weight_ndims == 2 || bias == nullptr,
"Expected B to be a 2D matrix if bias is given, got 1D.")

NVF_CHECK(
input->dtype() == weight->dtype(),
"Expected input and weight dtypes to have the same dtype, got: ",
input->dtype(),
" and ",
weight->dtype());

NVF_CHECK(
bias == nullptr || bias->dtype() == input->dtype(),
"Expected bias to have the same dtype as A and B, got: ",
bias->dtype(),
" and ",
input->dtype());
// For all other cases, create a new LinearOp
TensorView* out = newForLinear(input, weight, bias);
IrBuilder::create<LinearOp>(out, input, weight, bias);
return out;
}

TensorView* linear(TensorView* a, TensorView* b) {
return linear(a, b, nullptr /*bias*/);
// Bias-less overload: forwards to the three-argument linear with a null
// bias. Parameter names renamed to input/weight to match the declaration in
// composite.h.
TensorView* linear(TensorView* input, TensorView* weight) {
  return linear(input, weight, /*bias=*/nullptr);
}

LstmResult lstm(
Expand Down Expand Up @@ -293,15 +343,8 @@ static TensorView* newForMatmul(TensorView* tv_a, TensorView* tv_b) {
orig_domain_b, MatmulRole::INPUT_B, ndims_out);

for (auto idx : c10::irange(ndims_out - 1)) {
std::vector<IterDomain*> input_ids;
input_ids.reserve(2);
if (mapping_a[idx] != nullptr) {
input_ids.emplace_back(mapping_a[idx]);
}
if (mapping_b[idx] != nullptr) {
input_ids.emplace_back(mapping_b[idx]);
}
out_domain[idx] = ops::newOutputIterDomain(input_ids);
out_domain[idx] =
ops::newOutputIterDomain({mapping_a.at(idx), mapping_b.at(idx)});
}

out_domain[ndims_out - 1] = ops::newOutputIterDomain(
Expand Down
12 changes: 6 additions & 6 deletions csrc/ops/composite.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,17 @@ NVF_API LstmResult lstm(
TensorView* cell_x,
TensorView* out_x);

// Linear functions which takes in two tensors of shapes A[M,K] and
// B[N,K]. Takes in a options bias of shape [N] and performs
// out = A * B_Transpose + bias. The output dtype matches the dtype
// ofthe inputs which should match.
TensorView* linear(TensorView* a, TensorView* b, TensorView* bias);
// Linear function that takes in two tensors of shapes input[*, in_features]
// and weight[out_features, in_features] / [in_features], and an optional bias
// of shape [out_features] or a 0D scalar. Bias can only be given if weight is
// a 2-D tensor.
TensorView* linear(TensorView* input, TensorView* weight, TensorView* bias);
// This is an implementation detail to reflect when linear is called
// without a bias. This calls the above function. We use this function
// since it simplifies creating a Python API which takes optional arguments.
// Other options include using lambdas or creating a new RecordFunctor for
// Linear.
TensorView* linear(TensorView* a, TensorView* b);
TensorView* linear(TensorView* input, TensorView* weight);

NVF_API TensorView* sign(TensorView* x);
NVF_API Val* sign(Val* x);
Expand Down
51 changes: 50 additions & 1 deletion csrc/ops/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,53 @@ std::vector<IterDomain*> mapMatmulOpIterDomains(
return mapping;
}

// Aligns the iterdomains of one LinearOp operand to the output domain.
// Output positions with no corresponding operand axis are left nullptr so
// that newOutputIterDomain can skip them.
// Expected operand shapes:
//   Input A: {*, M, K}
//   Input B: {N, K} or {K}
//   Bias:    {N} or {}
std::vector<IterDomain*> mapLinearOpIterDomains(
    const std::vector<IterDomain*>& input_domain,
    MatmulRole input_role,
    size_t out_size) {
  std::vector<IterDomain*> mapping(out_size, nullptr);
  const auto inp_size = input_domain.size();

  switch (input_role) {
    case MatmulRole::INPUT_A: {
      // Guard against size_t underflow below: a 0-dim A would make
      // c10::irange(inp_size - 1) iterate over SIZE_MAX elements and
      // input_domain.back() would be UB.
      NVF_ERROR(inp_size > 0, "Input A must be at least 1D.");
      // The output matches input A for all but the last dimension; A's
      // trailing K axis maps to the output's reduction axis (out_size - 1).
      for (auto inx : c10::irange(inp_size - 1)) {
        mapping.at(inx) = input_domain[inx];
      }
      mapping.at(out_size - 1) = input_domain.back();
      break;
    }
    case MatmulRole::INPUT_B: {
      // Map N (if present) and K to the last two positions of the output.
      for (auto inx : c10::irange(inp_size)) {
        mapping.at(out_size - 1 - inx) = input_domain[inp_size - 1 - inx];
      }
      break;
    }
    case MatmulRole::INPUT_C: {
      if (inp_size > 0) {
        // Bias is a 1D tensor of shape {out_features}; it aligns with N at
        // out_size - 2, just before the reduction axis. A 0-dim bias has no
        // axis to map.
        mapping.at(out_size - 2) = input_domain[0];
      }
      break;
    }
    default:
      // The original passed only the message, which NVF_ERROR treats as a
      // (truthy) condition, so the error could never fire.
      NVF_ERROR(false, "Unexpected input type.");
  }
  return mapping;
}

// Adding these pragmas since gcc-12.2.1
// incorrectly reports a warning with the use of evaluate
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wfree-nonheap-object"
#endif
IterDomain* newOutputIterDomain(
const std::vector<IterDomain*>& ids,
const std::vector<IterDomain*>& input_ids,
const std::optional<IterType> force_iter_type) {
// For the start and stop offsets, take the maximum of input axes.
// For now, the offsets of both start and stop are always integer
Expand All @@ -242,6 +281,16 @@ IterDomain* newOutputIterDomain(
Val* expanded_extent_val = nullptr;
std::optional<IterType> iter_type = std::nullopt;

std::vector<IterDomain*> ids;
ids.reserve(input_ids.size());

// Filter out any nullptrs
std::copy_if(
input_ids.begin(),
input_ids.end(),
std::back_inserter(ids),
[](IterDomain* id) { return id != nullptr; });

for (auto id : ids) {
if (id->isBroadcast()) {
if (id->hasExpandedExtent()) {
Expand Down
21 changes: 21 additions & 0 deletions csrc/ops/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,32 @@ IterType promoteIterType(IterType type1, IterType type2);
// Mapping B: {nullptr, id_N})
// 3. A/B are atleast 1D and one of them is > 2D: [B, M, K] x [K, N] -> [B, M,
// N] (Mapping A: {id_B, id_M, nullptr}, Mapping B: {nullptr, nullptr, id_N})
// Args:
// 1. input_domain: root/rfactor domain without reductions for any input to
// MatmulOp
// 2. input_role: Specifies if the input is A / B (MatmulRole::INPUT_A /
// MatmulRole::INPUT_B)
// 3. out_size: MatmulOp output dimension (input and output may not be the same
// size).
std::vector<IterDomain*> mapMatmulOpIterDomains(
const std::vector<IterDomain*>& input_domain,
MatmulRole input_role,
size_t out_size);

// For LinearOp, the output is the same as the first input (A[*, in_features])
// for all but the last dimension. If the second input is 2D
// (B[out_features, in_features]), the out_features axis of the output comes
// from B. If bias is 1D (bias[out_features]) it maps to that same
// out_features axis of the output.
// Args:
// 1. input_domain: root/rfactor domain without reductions for any input to
// LinearOp
// 2. input_role: Specifies if the input is A / B / Bias
// (MatmulRole::INPUT_A / MatmulRole::INPUT_B / MatmulRole::INPUT_C)
// 3. out_size: LinearOp output dimension (input and output may not be the same
// size).
std::vector<IterDomain*> mapLinearOpIterDomains(
const std::vector<IterDomain*>& input_domain,
MatmulRole input_role,
size_t out_size);

// Takes a vector of aligned input iterdomains to create the output iterdomain.
// This is used if the input iterdomains are not trivially mapped to the output
// iterdomains. For eg: MatmulOp. If given, the forced_iter_type argument will
Expand Down
Loading