Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 47 additions & 5 deletions include/NeuraDialect/NeuraOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def Neura_FSubOp: Op<NeuraDialect, "fsub"> {
def Neura_FMulOp : Op<NeuraDialect, "fmul"> {
let summary = "Floating multiplication operation";
let opName = "fmul";
let arguments = (ins AnyType:$lhs, AnyType:$rhs);
let arguments = (ins AnyType:$lhs, Optional<AnyType>:$rhs);
let results = (outs AnyType:$result);
// let assemblyFormat = "$lhs `,` $rhs `,` attr-dict `:` type($result)";
// let traits = [SameOperandsAndResultElementType];
Expand All @@ -104,6 +104,48 @@ def Neura_FDivOp : Op<NeuraDialect, "fdiv"> {
let traits = [SameOperandsAndResultElementType];
}

// Defines a floating-point maximum operation.
def Neura_FMaxOp : Op<NeuraDialect, "fmax"> {
  let summary = "Floating-point maximum operation";
  let description = [{
    Returns the maximum of two floating-point values.
    Supports two NaN propagation semantics:
    - "maxnum": Returns non-NaN value when one operand is NaN (llvm.maxnum)
    - "maximum": Propagates NaN when any operand is NaN (llvm.maximum)

    Example:
    %result = neura.fmax<"maxnum">(%a, %b : f32) : f32 -> f32
    %result = neura.fmax<"maximum">(%a, %b : f32) : f32 -> f32
  }];
  let opName = "fmax";
  // rhs is Optional so constant-folding can drop the operand and carry the
  // folded constant as an attribute instead (see the FMax fuse pattern in
  // the FoldConstant pass, which builds this op with a null rhs).
  let arguments = (ins AnyType:$lhs, Optional<AnyType>:$rhs,
                   DefaultValuedAttr<StrAttr, "\"maxnum\"">:$nan_semantic);
  let results = (outs AnyType:$result);
  let traits = [SameOperandsAndResultElementType];
  // The `(..., $rhs : type)` group is anchored on $rhs, so both the operand
  // and its type are printed/parsed only when rhs is present.
  let assemblyFormat = "`<` $nan_semantic `>` `(` $lhs (`,` $rhs^ `:` type($rhs))? `)` attr-dict `:` type($lhs) `->` type($result)";
}

// Defines a floating-point minimum operation.
def Neura_FMinOp : Op<NeuraDialect, "fmin"> {
  let summary = "Floating-point minimum operation";
  let description = [{
    Returns the minimum of two floating-point values.
    Supports two NaN propagation semantics:
    - "minnum": Returns non-NaN value when one operand is NaN (llvm.minnum)
    - "minimum": Propagates NaN when any operand is NaN (llvm.minimum)

    Example:
    %result = neura.fmin<"minnum">(%a, %b : f32) : f32 -> f32
    %result = neura.fmin<"minimum">(%a, %b : f32) : f32 -> f32
  }];
  let opName = "fmin";
  // rhs is Optional so constant-folding can drop the operand and carry the
  // folded constant as an attribute instead (see the FMin fuse pattern in
  // the FoldConstant pass, which builds this op with a null rhs).
  let arguments = (ins AnyType:$lhs, Optional<AnyType>:$rhs,
                   DefaultValuedAttr<StrAttr, "\"minnum\"">:$nan_semantic);
  let results = (outs AnyType:$result);
  let traits = [SameOperandsAndResultElementType];
  // The `(..., $rhs : type)` group is anchored on $rhs, so both the operand
  // and its type are printed/parsed only when rhs is present.
  let assemblyFormat = "`<` $nan_semantic `>` `(` $lhs (`,` $rhs^ `:` type($rhs))? `)` attr-dict `:` type($lhs) `->` type($result)";
}

// Defines a bitwise OR operation.
def Neura_OrOp : Op<NeuraDialect, "or"> {
let summary = "Bitwise OR operation";
Expand Down Expand Up @@ -151,29 +193,29 @@ def Neura_StoreOp : Op<NeuraDialect, "store"> {
}

// Defines a load operation with integrated address calculation.
def Neura_LoadIndexedOp: Op<NeuraDialect, "load_indexed">{
def Neura_LoadIndexedOp: Op<NeuraDialect, "load_indexed", [AttrSizedOperandSegments]>{
let summary = "Load with integrated address calculation for multi-dimensional arrays";
let description = [{
Calculates the address using the base address and indices.
Load the value at the calculated address.
Example:
%value = neura.load_indexed %base [%arg1, %arg2] : f32
}];
let arguments = (ins AnyType:$base, Variadic<AnyType>:$indices);
let arguments = (ins Optional<AnyType>:$base, Variadic<AnyType>:$indices);
let results = (outs AnyType:$result);
let assemblyFormat = "$base `[` $indices `:` type($indices) `]` type($base) attr-dict `:` type($result)";
}

// Defines a store operation with integrated address calculation.
def Neura_StoreIndexedOp: Op<NeuraDialect, "store_indexed"> {
def Neura_StoreIndexedOp: Op<NeuraDialect, "store_indexed", [AttrSizedOperandSegments]> {
let summary = "Store with integrated address calculation for multi-dimensional arrays";
let description = [{
Calculates the address using the base address and indices.
Store the value at the calculated address.
Example:
neura.store_indexed %value, %base [%arg1, %arg2] : f32
}];
let arguments = (ins AnyType:$value, AnyType:$base, Variadic<AnyType>:$indices);
let arguments = (ins AnyType:$value, Optional<AnyType>:$base, Variadic<AnyType>:$indices);
let results = (outs);
let assemblyFormat = "$value `to` $base `[` $indices `:` type($indices) `]` type($base) attr-dict `:` type($value)";
}
Expand Down
80 changes: 80 additions & 0 deletions lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,82 @@ struct LlvmSRemToNeuraRem : public OpRewritePattern<LLVM::SRemOp> {
}
};

// Lowers llvm.maxnum into neura.fmax with "maxnum" NaN semantics (the
// non-NaN operand wins when exactly one input is NaN).
struct LlvmMaxNumToNeuraFMax : public OpRewritePattern<LLVM::MaxNumOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(LLVM::MaxNumOp op,
                                PatternRewriter &rewriter) const override {
    Type result_type = op->getResult(0).getType();

    // Only scalar floating-point values are supported; bail out otherwise.
    if (!mlir::isa<FloatType>(result_type)) {
      return failure();
    }

    rewriter.replaceOpWithNewOp<neura::FMaxOp>(
        op, result_type, op->getOperand(0), op->getOperand(1),
        rewriter.getStringAttr("maxnum"));
    return success();
  }
};

// Lowers llvm.maximum into neura.fmax with "maximum" NaN semantics (NaN
// propagates when either input is NaN).
struct LlvmMaximumToNeuraFMax : public OpRewritePattern<LLVM::MaximumOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(LLVM::MaximumOp op,
                                PatternRewriter &rewriter) const override {
    Type result_type = op->getResult(0).getType();

    // Only scalar floating-point values are supported; bail out otherwise.
    if (!mlir::isa<FloatType>(result_type)) {
      return failure();
    }

    rewriter.replaceOpWithNewOp<neura::FMaxOp>(
        op, result_type, op->getOperand(0), op->getOperand(1),
        rewriter.getStringAttr("maximum"));
    return success();
  }
};

// Lowers llvm.minnum into neura.fmin with "minnum" NaN semantics (the
// non-NaN operand wins when exactly one input is NaN).
struct LlvmMinNumToNeuraFMin : public OpRewritePattern<LLVM::MinNumOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(LLVM::MinNumOp op,
                                PatternRewriter &rewriter) const override {
    Type result_type = op->getResult(0).getType();

    // Only scalar floating-point values are supported; bail out otherwise.
    if (!mlir::isa<FloatType>(result_type)) {
      return failure();
    }

    rewriter.replaceOpWithNewOp<neura::FMinOp>(
        op, result_type, op->getOperand(0), op->getOperand(1),
        rewriter.getStringAttr("minnum"));
    return success();
  }
};

// Lowers llvm.minimum into neura.fmin with "minimum" NaN semantics (NaN
// propagates when either input is NaN).
struct LlvmMinimumToNeuraFMin : public OpRewritePattern<LLVM::MinimumOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(LLVM::MinimumOp op,
                                PatternRewriter &rewriter) const override {
    Type result_type = op->getResult(0).getType();

    // Only scalar floating-point values are supported; bail out otherwise.
    if (!mlir::isa<FloatType>(result_type)) {
      return failure();
    }

    rewriter.replaceOpWithNewOp<neura::FMinOp>(
        op, result_type, op->getOperand(0), op->getOperand(1),
        rewriter.getStringAttr("minimum"));
    return success();
  }
};

struct LlvmFDivToNeuraFDiv : public OpRewritePattern<mlir::LLVM::FDivOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -585,6 +661,10 @@ struct LowerLlvmToNeuraPass
patterns.add<LlvmShlToNeuraShl>(&getContext());
patterns.add<LlvmSDivToNeuraDiv>(&getContext());
patterns.add<LlvmSRemToNeuraRem>(&getContext());
patterns.add<LlvmMaxNumToNeuraFMax>(&getContext());
patterns.add<LlvmMaximumToNeuraFMax>(&getContext());
patterns.add<LlvmMinNumToNeuraFMin>(&getContext());
patterns.add<LlvmMinimumToNeuraFMin>(&getContext());
patterns.add<LlvmFDivToNeuraFDiv>(&getContext());
patterns.add<LlvmFPToSIToNeuraCast>(&getContext());
patterns.add<LlvmFMulAddToNeuraFMulFAdd>(&getContext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,76 @@ struct FuseFAddRhsConstantPattern
}
};

// Fuses a constant right-hand side of neura.fsub into the op itself: the
// rebuilt fsub carries only the non-constant operand, and the constant
// travels as the "rhs_value" attribute. fsub is non-commutative, so
// isCommutative() keeps the base-class default (presumably false — TODO
// confirm against FuseRhsConstantPattern).
struct FuseFSubRhsConstantPattern
    : public FuseRhsConstantPattern<neura::FSubOp> {
  using FuseRhsConstantPattern<neura::FSubOp>::FuseRhsConstantPattern;

  Operation *
  createOpWithFusedRhsConstant(neura::FSubOp op, Value non_const_operand,
                               Attribute rhs_value,
                               PatternRewriter &rewriter) const override {
    // Builds the fsub with the rhs operand omitted; the constant value is
    // attached as an attribute instead.
    auto fused_op = rewriter.create<neura::FSubOp>(
        op.getLoc(), op.getResult().getType(), non_const_operand,
        /*rhs=*/nullptr);
    addConstantAttribute(fused_op, "rhs_value", rhs_value);
    return fused_op;
  }
};

// Fuses a constant operand of neura.fmul into the op as the "rhs_value"
// attribute. fmul is commutative, so the base pattern may fuse a constant
// found on either side.
struct FuseFMulRhsConstantPattern
    : public FuseRhsConstantPattern<neura::FMulOp> {
  using FuseRhsConstantPattern<neura::FMulOp>::FuseRhsConstantPattern;

  bool isCommutative() const override { return true; }

  Operation *
  createOpWithFusedRhsConstant(neura::FMulOp op, Value non_const_operand,
                               Attribute rhs_value,
                               PatternRewriter &rewriter) const override {
    // Builds the fmul with the rhs operand omitted; the constant value is
    // attached as an attribute instead.
    auto fused_op = rewriter.create<neura::FMulOp>(
        op.getLoc(), op.getResult().getType(), non_const_operand,
        /*rhs=*/nullptr);
    addConstantAttribute(fused_op, "rhs_value", rhs_value);
    return fused_op;
  }
};

// Fuses a constant operand of neura.fmax into the op as the "rhs_value"
// attribute. fmax is commutative for both NaN semantics, so a constant on
// either side may be fused; the nan_semantic attribute is preserved.
struct FuseFMaxRhsConstantPattern
    : public FuseRhsConstantPattern<neura::FMaxOp> {
  using FuseRhsConstantPattern<neura::FMaxOp>::FuseRhsConstantPattern;

  bool isCommutative() const override { return true; }

  Operation *
  createOpWithFusedRhsConstant(neura::FMaxOp op, Value non_const_operand,
                               Attribute rhs_value,
                               PatternRewriter &rewriter) const override {
    // Builds the fmax with the rhs operand omitted, carrying over the
    // original op's NaN-propagation semantics.
    auto fused_op = rewriter.create<neura::FMaxOp>(
        op.getLoc(), op.getResult().getType(), non_const_operand,
        /*rhs=*/nullptr, op.getNanSemantic());
    addConstantAttribute(fused_op, "rhs_value", rhs_value);
    return fused_op;
  }
};

// Fuses a constant operand of neura.fmin into the op as the "rhs_value"
// attribute. fmin is commutative for both NaN semantics, so a constant on
// either side may be fused; the nan_semantic attribute is preserved.
struct FuseFMinRhsConstantPattern
    : public FuseRhsConstantPattern<neura::FMinOp> {
  using FuseRhsConstantPattern<neura::FMinOp>::FuseRhsConstantPattern;

  bool isCommutative() const override { return true; }

  Operation *
  createOpWithFusedRhsConstant(neura::FMinOp op, Value non_const_operand,
                               Attribute rhs_value,
                               PatternRewriter &rewriter) const override {
    // Builds the fmin with the rhs operand omitted, carrying over the
    // original op's NaN-propagation semantics.
    auto fused_op = rewriter.create<neura::FMinOp>(
        op.getLoc(), op.getResult().getType(), non_const_operand,
        /*rhs=*/nullptr, op.getNanSemantic());
    addConstantAttribute(fused_op, "rhs_value", rhs_value);
    return fused_op;
  }
};

struct FuseDivRhsConstantPattern : public FuseRhsConstantPattern<neura::DivOp> {
using FuseRhsConstantPattern<neura::DivOp>::FuseRhsConstantPattern;

Expand Down Expand Up @@ -353,6 +423,98 @@ struct FuseStoreAddrConstantPattern : public OpRewritePattern<neura::StoreOp> {
}
};

// =========================================
// FuseLoadIndexedBaseConstantPattern
// Folds constant base pointer for LoadIndexed operation.
// =========================================
// Folds a constant base pointer of neura.load_indexed into the op: the
// rebuilt load carries no base operand and instead stores the constant as
// the "lhs_value" attribute (the name mirrors the operand-position naming
// used by the other fuse patterns in this file).
struct FuseLoadIndexedBaseConstantPattern
    : public OpRewritePattern<neura::LoadIndexedOp> {
  using OpRewritePattern<neura::LoadIndexedOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(neura::LoadIndexedOp load_indexed_op,
                                PatternRewriter &rewriter) const override {
    Value base = load_indexed_op.getBase();

    // Checks if base exists and originates from a constant.
    if (!base || !isOriginConstantOp(base)) {
      return failure();
    }

    // NOTE: dyn_cast may yield null if isOriginConstantOp() traced the
    // constant through intermediate ops rather than a direct
    // neura.ConstantOp definer — every use below must be null-guarded.
    auto constant_op = dyn_cast<neura::ConstantOp>(base.getDefiningOp());
    Attribute base_const_value = getOriginConstantValue(base);
    if (!base_const_value) {
      // No materializable constant value; leave the op untouched.
      return failure();
    }

    // Copies the index operands so they survive the rewrite below.
    SmallVector<Value> indices(load_indexed_op.getIndices().begin(),
                               load_indexed_op.getIndices().end());

    // Creates a new load_indexed with no base but with the constant folded
    // into the lhs_value attribute.
    auto fused_load_indexed = rewriter.create<neura::LoadIndexedOp>(
        load_indexed_op.getLoc(), load_indexed_op.getResult().getType(),
        /*base=*/nullptr, indices);
    addConstantAttribute(fused_load_indexed, "lhs_value", base_const_value);

    // Replaces the original load_indexed.
    rewriter.replaceOp(load_indexed_op, fused_load_indexed);

    // Cleans up the constant if it became dead (null-safe: see dyn_cast
    // above — previously this dereferenced constant_op unconditionally).
    if (constant_op && constant_op->use_empty()) {
      rewriter.eraseOp(constant_op);
    }

    return success();
  }
};

// =========================================
// FuseStoreIndexedBaseConstantPattern
// Folds constant base pointer for StoreIndexed operation.
// =========================================
// Folds a constant base pointer of neura.store_indexed into the op: the
// rebuilt store keeps the value operand, drops the base, and stores the
// constant as the "rhs_value" attribute (the name mirrors the
// operand-position naming used by the other fuse patterns in this file).
struct FuseStoreIndexedBaseConstantPattern
    : public OpRewritePattern<neura::StoreIndexedOp> {
  using OpRewritePattern<neura::StoreIndexedOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(neura::StoreIndexedOp store_indexed_op,
                                PatternRewriter &rewriter) const override {
    Value base = store_indexed_op.getBase();

    // Checks if base exists and originates from a constant.
    if (!base || !isOriginConstantOp(base)) {
      return failure();
    }

    // NOTE: dyn_cast may yield null if isOriginConstantOp() traced the
    // constant through intermediate ops rather than a direct
    // neura.ConstantOp definer — every use below must be null-guarded.
    auto constant_op = dyn_cast<neura::ConstantOp>(base.getDefiningOp());
    Attribute base_const_value = getOriginConstantValue(base);
    if (!base_const_value) {
      // No materializable constant value; leave the op untouched.
      return failure();
    }

    // Copies the index operands so they survive the rewrite below.
    SmallVector<Value> indices(store_indexed_op.getIndices().begin(),
                               store_indexed_op.getIndices().end());

    // Creates a new store_indexed with no base but with the constant folded
    // into the rhs_value attribute. The stored value operand is kept.
    auto fused_store_indexed = rewriter.create<neura::StoreIndexedOp>(
        store_indexed_op.getLoc(), store_indexed_op.getValue(),
        /*base=*/nullptr, indices);
    addConstantAttribute(fused_store_indexed, "rhs_value", base_const_value);

    // Replaces the original store_indexed.
    rewriter.replaceOp(store_indexed_op, fused_store_indexed);

    // Cleans up the constant if it became dead (null-safe: see dyn_cast
    // above — previously this dereferenced constant_op unconditionally).
    if (constant_op && constant_op->use_empty()) {
      rewriter.eraseOp(constant_op);
    }

    return success();
  }
};

// =========================================
// FoldConstantPass Implementation
// =========================================
Expand All @@ -374,12 +536,18 @@ struct FoldConstantPass
patterns.add<FuseMulRhsConstantPattern>(&getContext());
patterns.add<FuseICmpRhsConstantPattern>(&getContext());
patterns.add<FuseFAddRhsConstantPattern>(&getContext());
patterns.add<FuseFSubRhsConstantPattern>(&getContext());
patterns.add<FuseFMulRhsConstantPattern>(&getContext());
patterns.add<FuseFMaxRhsConstantPattern>(&getContext());
patterns.add<FuseFMinRhsConstantPattern>(&getContext());
patterns.add<FuseDivRhsConstantPattern>(&getContext());
patterns.add<FuseRemRhsConstantPattern>(&getContext());

patterns.add<FuseConstantAndGrantPattern>(&getContext());
patterns.add<FuseGepBaseConstantPattern>(&getContext());
patterns.add<FuseStoreAddrConstantPattern>(&getContext());
patterns.add<FuseLoadIndexedBaseConstantPattern>(&getContext());
patterns.add<FuseStoreIndexedBaseConstantPattern>(&getContext());
FrozenRewritePatternSet frozen(std::move(patterns));

// Applies to every region inside the module (regardless of func type,
Expand Down
2 changes: 1 addition & 1 deletion test/controflow_fuse/perfect_nested/perfect_nested.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -195,4 +195,4 @@ module attributes {} {
// CTRL2DATA-NEXT: }


// MAPPING: func.func @_Z10bert_node1PA1_A1_A1_A1_A128_bPA1_A128_S1_(%arg0: memref<?x1x1x1x1x128xi8>, %arg1: memref<?x1x128x1x1x128xi8>) attributes {accelerator = "neura", dataflow_mode = "predicate", llvm.linkage = #llvm.linkage<external>, mapping_info = {compiled_ii = 10 : i32, mapping_mode = "spatial-temporal", mapping_strategy = "heuristic", rec_mii = 8 : i32, res_mii = 3 : i32, x_tiles = 4 : i32, y_tiles = 4 : i32}} {
// MAPPING: func.func @_Z10bert_node1PA1_A1_A1_A1_A128_bPA1_A128_S1_(%arg0: memref<?x1x1x1x1x128xi8>, %arg1: memref<?x1x128x1x1x128xi8>) attributes {accelerator = "neura", dataflow_mode = "predicate", llvm.linkage = #llvm.linkage<external>, mapping_info = {compiled_ii = 10 : i32, mapping_mode = "spatial-temporal", mapping_strategy = "heuristic", rec_mii = 8 : i32, res_mii = 2 : i32, x_tiles = 4 : i32, y_tiles = 4 : i32}} {
Loading