From cf4175771fd86bd68e0365ff2a0488cf0467bb69 Mon Sep 17 00:00:00 2001 From: tangyz <739245980@qq.com> Date: Tue, 21 Oct 2025 00:50:05 +0800 Subject: [PATCH 1/4] Add constant folding support for FMul, FSub, FMax, FMin, LoadIndexed, StoreIndexed --- include/NeuraDialect/NeuraOps.td | 42 ++++- .../LlvmToNeura/LlvmToNeuraPass.cpp | 38 ++++ .../HwAgnosticOpt/FoldConstantPass.cpp | 168 ++++++++++++++++++ 3 files changed, 243 insertions(+), 5 deletions(-) diff --git a/include/NeuraDialect/NeuraOps.td b/include/NeuraDialect/NeuraOps.td index 4ead5f21..cedf3d09 100644 --- a/include/NeuraDialect/NeuraOps.td +++ b/include/NeuraDialect/NeuraOps.td @@ -84,7 +84,7 @@ def Neura_FSubOp: Op { def Neura_FMulOp : Op { let summary = "Floating multiplication operation"; let opName = "fmul"; - let arguments = (ins AnyType:$lhs, AnyType:$rhs); + let arguments = (ins AnyType:$lhs, Optional:$rhs); let results = (outs AnyType:$result); // let assemblyFormat = "$lhs `,` $rhs `,` attr-dict `:` type($result)"; // let traits = [SameOperandsAndResultElementType]; @@ -104,6 +104,38 @@ def Neura_FDivOp : Op { let traits = [SameOperandsAndResultElementType]; } +// Defines a floating-point maximum operation. +def Neura_FMaxOp : Op { + let summary = "Floating-point maximum operation"; + let description = [{ + Returns the maximum of two floating-point values. + Corresponds to llvm.maxnum intrinsic. + + Example: + %result = neura.fmax %a, %b : f32 + }]; + let opName = "fmax"; + let arguments = (ins AnyType:$lhs, Optional:$rhs); + let results = (outs AnyType:$result); + let traits = [SameOperandsAndResultElementType]; +} + +// Defines a floating-point minimum operation. +def Neura_FMinOp : Op { + let summary = "Floating-point minimum operation"; + let description = [{ + Returns the minimum of two floating-point values. + Corresponds to llvm.minnum intrinsic. 
+ + Example: + %result = neura.fmin %a, %b : f32 + }]; + let opName = "fmin"; + let arguments = (ins AnyType:$lhs, Optional:$rhs); + let results = (outs AnyType:$result); + let traits = [SameOperandsAndResultElementType]; +} + // Defines a bitwise OR operation. def Neura_OrOp : Op { let summary = "Bitwise OR operation"; @@ -151,7 +183,7 @@ def Neura_StoreOp : Op { } // Defines a load operation with integrated address calculation. -def Neura_LoadIndexedOp: Op{ +def Neura_LoadIndexedOp: Op{ let summary = "Load with integrated address calculation for multi-dimensional arrays"; let description = [{ Calculates the address using the base address and indices. @@ -159,13 +191,13 @@ def Neura_LoadIndexedOp: Op{ Example: %value = neura.load_indexed %base [%arg1, %arg2] : f32 }]; - let arguments = (ins AnyType:$base, Variadic:$indices); + let arguments = (ins Optional:$base, Variadic:$indices); let results = (outs AnyType:$result); let assemblyFormat = "$base `[` $indices `:` type($indices) `]` type($base) attr-dict `:` type($result)"; } //Defines a store operation with integrated address calculation. -def Neura_StoreIndexedOp: Op { +def Neura_StoreIndexedOp: Op { let summary = "Store with integrated address calculation for multi-dimensional arrays"; let description = [{ Calculates the address using the base address and indices. 
@@ -173,7 +205,7 @@ def Neura_StoreIndexedOp: Op { Example: neura.store_indexed %value, %base [%arg1, %arg2] : f32 }]; - let arguments = (ins AnyType:$value, AnyType:$base, Variadic:$indices); + let arguments = (ins AnyType:$value, Optional:$base, Variadic:$indices); let results = (outs); let assemblyFormat = "$value `to` $base `[` $indices `:` type($indices) `]` type($base) attr-dict `:` type($value)"; } diff --git a/lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp b/lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp index f0f34821..b24523b0 100644 --- a/lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp +++ b/lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp @@ -134,6 +134,42 @@ struct LlvmSRemToNeuraRem : public OpRewritePattern { } }; +struct LlvmMaxNumToNeuraFMax : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LLVM::MaxNumOp op, + PatternRewriter &rewriter) const override { + Value lhs = op->getOperand(0); + Value rhs = op->getOperand(1); + Type resultType = op->getResult(0).getType(); + + // Only matches scalar float. + if (!mlir::isa(resultType)) + return failure(); + + rewriter.replaceOpWithNewOp(op, resultType, lhs, rhs); + return success(); + } +}; + +struct LlvmMinNumToNeuraFMin : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LLVM::MinNumOp op, + PatternRewriter &rewriter) const override { + Value lhs = op->getOperand(0); + Value rhs = op->getOperand(1); + Type resultType = op->getResult(0).getType(); + + // Only matches scalar float. 
+ if (!mlir::isa(resultType)) + return failure(); + + rewriter.replaceOpWithNewOp(op, resultType, lhs, rhs); + return success(); + } +}; + struct LlvmFDivToNeuraFDiv : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -585,6 +621,8 @@ struct LowerLlvmToNeuraPass patterns.add(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); diff --git a/lib/NeuraDialect/Transforms/Optimizations/HwAgnosticOpt/FoldConstantPass.cpp b/lib/NeuraDialect/Transforms/Optimizations/HwAgnosticOpt/FoldConstantPass.cpp index 88848a46..86a86264 100644 --- a/lib/NeuraDialect/Transforms/Optimizations/HwAgnosticOpt/FoldConstantPass.cpp +++ b/lib/NeuraDialect/Transforms/Optimizations/HwAgnosticOpt/FoldConstantPass.cpp @@ -239,6 +239,76 @@ struct FuseFAddRhsConstantPattern } }; +struct FuseFSubRhsConstantPattern + : public FuseRhsConstantPattern { + using FuseRhsConstantPattern::FuseRhsConstantPattern; + + Operation * + createOpWithFusedRhsConstant(neura::FSubOp op, Value non_const_operand, + Attribute rhs_value, + PatternRewriter &rewriter) const override { + auto fused_op = rewriter.create( + op.getLoc(), op.getResult().getType(), non_const_operand, + /*rhs=*/nullptr); + addConstantAttribute(fused_op, "rhs_value", rhs_value); + return fused_op; + } +}; + +struct FuseFMulRhsConstantPattern + : public FuseRhsConstantPattern { + using FuseRhsConstantPattern::FuseRhsConstantPattern; + + bool isCommutative() const override { return true; } + + Operation * + createOpWithFusedRhsConstant(neura::FMulOp op, Value non_const_operand, + Attribute rhs_value, + PatternRewriter &rewriter) const override { + auto fused_op = rewriter.create( + op.getLoc(), op.getResult().getType(), non_const_operand, + /*rhs=*/nullptr); + addConstantAttribute(fused_op, "rhs_value", rhs_value); + return fused_op; + } +}; + +struct 
FuseFMaxRhsConstantPattern + : public FuseRhsConstantPattern { + using FuseRhsConstantPattern::FuseRhsConstantPattern; + + bool isCommutative() const override { return true; } + + Operation * + createOpWithFusedRhsConstant(neura::FMaxOp op, Value non_const_operand, + Attribute rhs_value, + PatternRewriter &rewriter) const override { + auto fused_op = rewriter.create( + op.getLoc(), op.getResult().getType(), non_const_operand, + /*rhs=*/nullptr); + addConstantAttribute(fused_op, "rhs_value", rhs_value); + return fused_op; + } +}; + +struct FuseFMinRhsConstantPattern + : public FuseRhsConstantPattern { + using FuseRhsConstantPattern::FuseRhsConstantPattern; + + bool isCommutative() const override { return true; } + + Operation * + createOpWithFusedRhsConstant(neura::FMinOp op, Value non_const_operand, + Attribute rhs_value, + PatternRewriter &rewriter) const override { + auto fused_op = rewriter.create( + op.getLoc(), op.getResult().getType(), non_const_operand, + /*rhs=*/nullptr); + addConstantAttribute(fused_op, "rhs_value", rhs_value); + return fused_op; + } +}; + struct FuseDivRhsConstantPattern : public FuseRhsConstantPattern { using FuseRhsConstantPattern::FuseRhsConstantPattern; @@ -353,6 +423,98 @@ struct FuseStoreAddrConstantPattern : public OpRewritePattern { } }; +// ========================================= +// FuseLoadIndexedBaseConstantPattern +// Folds constant base pointer for LoadIndexed operation. +// ========================================= +struct FuseLoadIndexedBaseConstantPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(neura::LoadIndexedOp load_indexed_op, + PatternRewriter &rewriter) const override { + Value base = load_indexed_op.getBase(); + + // Checks if base exists and is a constant. 
+ if (!base || !isOriginConstantOp(base)) { + return failure(); + } + + auto constant_op = dyn_cast(base.getDefiningOp()); + Attribute base_const_value = getOriginConstantValue(base); + + // Gets all indices. + SmallVector indices; + for (Value idx : load_indexed_op.getIndices()) { + indices.push_back(idx); + } + + // Creates new LoadIndexed with no base but with lhs_value attribute. + auto fused_load_indexed = rewriter.create( + load_indexed_op.getLoc(), + load_indexed_op.getResult().getType(), + /*base=*/nullptr, + indices); + addConstantAttribute(fused_load_indexed, "lhs_value", base_const_value); + + // Replaces the original LoadIndexed. + rewriter.replaceOp(load_indexed_op, fused_load_indexed); + + // Cleans up constant if no longer used. + if (constant_op->use_empty()) { + rewriter.eraseOp(constant_op); + } + + return success(); + } +}; + +// ========================================= +// FuseStoreIndexedBaseConstantPattern +// Folds constant base pointer for StoreIndexed operation. +// ========================================= +struct FuseStoreIndexedBaseConstantPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(neura::StoreIndexedOp store_indexed_op, + PatternRewriter &rewriter) const override { + Value base = store_indexed_op.getBase(); + + // Checks if base exists and is a constant. + if (!base || !isOriginConstantOp(base)) { + return failure(); + } + + auto constant_op = dyn_cast(base.getDefiningOp()); + Attribute base_const_value = getOriginConstantValue(base); + + // Gets all indices. + SmallVector indices; + for (Value idx : store_indexed_op.getIndices()) { + indices.push_back(idx); + } + + // Creates new StoreIndexed with no base but with rhs_value attribute. + auto fused_store_indexed = rewriter.create( + store_indexed_op.getLoc(), + store_indexed_op.getValue(), // Keeps the value operand. 
+ /*base=*/nullptr, + indices); + addConstantAttribute(fused_store_indexed, "rhs_value", base_const_value); + + // Replaces the original StoreIndexed. + rewriter.replaceOp(store_indexed_op, fused_store_indexed); + + // Cleans up constant if no longer used. + if (constant_op->use_empty()) { + rewriter.eraseOp(constant_op); + } + + return success(); + } +}; + // ========================================= // FoldConstantPass Implementation // ========================================= @@ -374,12 +536,18 @@ struct FoldConstantPass patterns.add(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); FrozenRewritePatternSet frozen(std::move(patterns)); // Applies to every region inside the module (regardless of func type, From 7406de3b61c154c34caf70aad24af9c7a6f6cc2c Mon Sep 17 00:00:00 2001 From: tangyz <739245980@qq.com> Date: Tue, 21 Oct 2025 02:13:40 +0800 Subject: [PATCH 2/4] Add FMax and FMin operations support in neura-interpreter and update test cases --- .../perfect_nested/perfect_nested.mlir | 2 +- .../simple_loop/simple_loop.mlir | 24 +-- .../interpreter/basic_operation/fmax.mlir | 68 ++++++++ .../interpreter/basic_operation/fmin.mlir | 68 ++++++++ .../constant_folding/simple_loop.mlir | 60 ++++--- tools/neura-interpreter/neura-interpreter.cpp | 152 ++++++++++++++++++ 6 files changed, 334 insertions(+), 40 deletions(-) create mode 100644 test/neura/interpreter/basic_operation/fmax.mlir create mode 100644 test/neura/interpreter/basic_operation/fmin.mlir diff --git a/test/controflow_fuse/perfect_nested/perfect_nested.mlir b/test/controflow_fuse/perfect_nested/perfect_nested.mlir index 93cb29ac..0b565887 100644 --- 
a/test/controflow_fuse/perfect_nested/perfect_nested.mlir +++ b/test/controflow_fuse/perfect_nested/perfect_nested.mlir @@ -195,4 +195,4 @@ module attributes {} { // CTRL2DATA-NEXT: } -// MAPPING: func.func @_Z10bert_node1PA1_A1_A1_A1_A128_bPA1_A128_S1_(%arg0: memref, %arg1: memref) attributes {accelerator = "neura", dataflow_mode = "predicate", llvm.linkage = #llvm.linkage, mapping_info = {compiled_ii = 10 : i32, mapping_mode = "spatial-temporal", mapping_strategy = "heuristic", rec_mii = 8 : i32, res_mii = 3 : i32, x_tiles = 4 : i32, y_tiles = 4 : i32}} { \ No newline at end of file +// MAPPING: func.func @_Z10bert_node1PA1_A1_A1_A1_A128_bPA1_A128_S1_(%arg0: memref, %arg1: memref) attributes {accelerator = "neura", dataflow_mode = "predicate", llvm.linkage = #llvm.linkage, mapping_info = {compiled_ii = 10 : i32, mapping_mode = "spatial-temporal", mapping_strategy = "heuristic", rec_mii = 8 : i32, res_mii = 2 : i32, x_tiles = 4 : i32, y_tiles = 4 : i32}} { \ No newline at end of file diff --git a/test/controflow_fuse/simple_loop/simple_loop.mlir b/test/controflow_fuse/simple_loop/simple_loop.mlir index a05ad8a9..43976f4d 100644 --- a/test/controflow_fuse/simple_loop/simple_loop.mlir +++ b/test/controflow_fuse/simple_loop/simple_loop.mlir @@ -185,23 +185,13 @@ module attributes {} { // FUSE: func.func @_Z11simple_loopPiS_(%arg0: memref, %arg1: memref) attributes {accelerator = "neura", dataflow_mode = "predicate", llvm.linkage = #llvm.linkage} { -// FUSE-NEXT: %0 = "neura.grant_once"() <{constant_value = "%arg0"}> : () -> !neura.data, i1> -// FUSE-NEXT: %1 = "neura.grant_once"() <{constant_value = "%arg1"}> : () -> !neura.data, i1> -// FUSE-NEXT: %2 = neura.reserve : !neura.data, i1> -// FUSE-NEXT: %3 = "neura.phi"(%2, %1) : (!neura.data, i1>, !neura.data, i1>) -> !neura.data, i1> -// FUSE-NEXT: %4 = neura.reserve : !neura.data, i1> -// FUSE-NEXT: %5 = "neura.phi"(%4, %0) : (!neura.data, i1>, !neura.data, i1>) -> !neura.data, i1> -// FUSE-NEXT: %6 = 
"neura.grant_always"() <{constant_value = true}> : () -> !neura.data -// FUSE-NEXT: %nextindex, %valid = "neura.loop_control"(%6) <{end = 128 : i64, iterationType = "increment", start = 0 : i64, step = 1 : i64}> : (!neura.data) -> (!neura.data, !neura.data) -// FUSE-NEXT: %7 = neura.grant_predicate %5, %valid : !neura.data, i1>, !neura.data -> !neura.data, i1> -// FUSE-NEXT: %8 = neura.grant_predicate %3, %valid : !neura.data, i1>, !neura.data -> !neura.data, i1> -// FUSE-NEXT: %9 = neura.load_indexed %7[%nextindex : !neura.data] !neura.data, i1> : !neura.data -// FUSE-NEXT: %10 = "neura.mul"(%9) {rhs_value = 2 : i32} : (!neura.data) -> !neura.data -// FUSE-NEXT: %11 = "neura.add"(%10) {rhs_value = 1 : i32} : (!neura.data) -> !neura.data -// FUSE-NEXT: neura.store_indexed %11 to %8[%nextindex : !neura.data] !neura.data, i1> : !neura.data -// FUSE-NEXT: neura.ctrl_mov %7 -> %4 : !neura.data, i1> !neura.data, i1> -// FUSE-NEXT: neura.ctrl_mov %8 -> %2 : !neura.data, i1> !neura.data, i1> +// FUSE-NEXT: %0 = "neura.grant_always"() <{constant_value = true}> : () -> !neura.data +// FUSE-NEXT: %nextindex, %valid = "neura.loop_control"(%0) <{end = 128 : i64, iterationType = "increment", start = 0 : i64, step = 1 : i64}> : (!neura.data) -> (!neura.data, !neura.data) +// FUSE-NEXT: %1 = neura.load_indexed [%nextindex : !neura.data] {lhs_value = "%arg0"} : !neura.data +// FUSE-NEXT: %2 = "neura.mul"(%1) {rhs_value = 2 : i32} : (!neura.data) -> !neura.data +// FUSE-NEXT: %3 = "neura.add"(%2) {rhs_value = 1 : i32} : (!neura.data) -> !neura.data +// FUSE-NEXT: neura.store_indexed %3 to [%nextindex : !neura.data] {rhs_value = "%arg1"} : !neura.data // FUSE-NEXT: "neura.return"() : () -> () // FUSE-NEXT: } -// FUSE-MAPPING: func.func @_Z11simple_loopPiS_(%arg0: memref, %arg1: memref) attributes {accelerator = "neura", dataflow_mode = "predicate", llvm.linkage = #llvm.linkage, mapping_info = {compiled_ii = 2 : i32, mapping_mode = "spatial-temporal", mapping_strategy = "heuristic", 
rec_mii = 2 : i32, res_mii = 1 : i32, x_tiles = 4 : i32, y_tiles = 4 : i32}} { \ No newline at end of file +// FUSE-MAPPING: func.func @_Z11simple_loopPiS_(%arg0: memref, %arg1: memref) attributes {accelerator = "neura", dataflow_mode = "predicate", llvm.linkage = #llvm.linkage, mapping_info = {compiled_ii = 1 : i32, mapping_mode = "spatial-temporal", mapping_strategy = "heuristic", rec_mii = 1 : i32, res_mii = 1 : i32, x_tiles = 4 : i32, y_tiles = 4 : i32}} { \ No newline at end of file diff --git a/test/neura/interpreter/basic_operation/fmax.mlir b/test/neura/interpreter/basic_operation/fmax.mlir new file mode 100644 index 00000000..9e86374b --- /dev/null +++ b/test/neura/interpreter/basic_operation/fmax.mlir @@ -0,0 +1,68 @@ +// RUN: neura-interpreter %s --verbose | FileCheck %s + +// ===----------------------------------------------------------------------===// +// Test 1: Max of two positive floats +// ===----------------------------------------------------------------------===// +func.func @test_fmax_positive() -> f32 { + %a = arith.constant 10.0 : f32 + %b = arith.constant 32.0 : f32 + %res = "neura.fmax"(%a, %b) : (f32, f32) -> f32 + // CHECK: [neura-interpreter] → Output: 32.000000 + return %res : f32 +} + +// ===----------------------------------------------------------------------===// +// Test 2: Max of negative and positive float +// ===----------------------------------------------------------------------===// +func.func @test_fmax_mixed() -> f32 { + %a = arith.constant -5.0 : f32 + %b = arith.constant 3.0 : f32 + %res = "neura.fmax"(%a, %b) : (f32, f32) -> f32 + // CHECK: [neura-interpreter] → Output: 3.000000 + return %res : f32 +} + +// ===----------------------------------------------------------------------===// +// Test 3: Max of two negative floats +// ===----------------------------------------------------------------------===// +func.func @test_fmax_negative() -> f32 { + %a = arith.constant -10.0 : f32 + %b = arith.constant -3.0 : f32 + %res 
= "neura.fmax"(%a, %b) : (f32, f32) -> f32 + // CHECK: [neura-interpreter] → Output: -3.000000 + return %res : f32 +} + +// ===----------------------------------------------------------------------===// +// Test 4: Max with zero +// ===----------------------------------------------------------------------===// +func.func @test_fmax_zero() -> f32 { + %a = arith.constant 0.0 : f32 + %b = arith.constant -7.0 : f32 + %res = "neura.fmax"(%a, %b) : (f32, f32) -> f32 + // CHECK: [neura-interpreter] → Output: 0.000000 + return %res : f32 +} + +// ===----------------------------------------------------------------------===// +// Test 5: Max of equal values +// ===----------------------------------------------------------------------===// +func.func @test_fmax_equal() -> f32 { + %a = arith.constant 5.5 : f32 + %b = arith.constant 5.5 : f32 + %res = "neura.fmax"(%a, %b) : (f32, f32) -> f32 + // CHECK: [neura-interpreter] → Output: 5.500000 + return %res : f32 +} + +// ===----------------------------------------------------------------------===// +// Test 6: Max of fractional values +// ===----------------------------------------------------------------------===// +func.func @test_fmax_fraction() -> f32 { + %a = arith.constant 2.75 : f32 + %b = arith.constant 2.5 : f32 + %res = "neura.fmax"(%a, %b) : (f32, f32) -> f32 + // CHECK: [neura-interpreter] → Output: 2.750000 + return %res : f32 +} + diff --git a/test/neura/interpreter/basic_operation/fmin.mlir b/test/neura/interpreter/basic_operation/fmin.mlir new file mode 100644 index 00000000..afba4bb9 --- /dev/null +++ b/test/neura/interpreter/basic_operation/fmin.mlir @@ -0,0 +1,68 @@ +// RUN: neura-interpreter %s --verbose | FileCheck %s + +// ===----------------------------------------------------------------------===// +// Test 1: Min of two positive floats +// ===----------------------------------------------------------------------===// +func.func @test_fmin_positive() -> f32 { + %a = arith.constant 10.0 : f32 + %b = 
arith.constant 32.0 : f32 + %res = "neura.fmin"(%a, %b) : (f32, f32) -> f32 + // CHECK: [neura-interpreter] → Output: 10.000000 + return %res : f32 +} + +// ===----------------------------------------------------------------------===// +// Test 2: Min of negative and positive float +// ===----------------------------------------------------------------------===// +func.func @test_fmin_mixed() -> f32 { + %a = arith.constant -5.0 : f32 + %b = arith.constant 3.0 : f32 + %res = "neura.fmin"(%a, %b) : (f32, f32) -> f32 + // CHECK: [neura-interpreter] → Output: -5.000000 + return %res : f32 +} + +// ===----------------------------------------------------------------------===// +// Test 3: Min of two negative floats +// ===----------------------------------------------------------------------===// +func.func @test_fmin_negative() -> f32 { + %a = arith.constant -10.0 : f32 + %b = arith.constant -3.0 : f32 + %res = "neura.fmin"(%a, %b) : (f32, f32) -> f32 + // CHECK: [neura-interpreter] → Output: -10.000000 + return %res : f32 +} + +// ===----------------------------------------------------------------------===// +// Test 4: Min with zero +// ===----------------------------------------------------------------------===// +func.func @test_fmin_zero() -> f32 { + %a = arith.constant 0.0 : f32 + %b = arith.constant 7.0 : f32 + %res = "neura.fmin"(%a, %b) : (f32, f32) -> f32 + // CHECK: [neura-interpreter] → Output: 0.000000 + return %res : f32 +} + +// ===----------------------------------------------------------------------===// +// Test 5: Min of equal values +// ===----------------------------------------------------------------------===// +func.func @test_fmin_equal() -> f32 { + %a = arith.constant 5.5 : f32 + %b = arith.constant 5.5 : f32 + %res = "neura.fmin"(%a, %b) : (f32, f32) -> f32 + // CHECK: [neura-interpreter] → Output: 5.500000 + return %res : f32 +} + +// ===----------------------------------------------------------------------===// +// Test 6: Min of fractional 
values +// ===----------------------------------------------------------------------===// +func.func @test_fmin_fraction() -> f32 { + %a = arith.constant 2.75 : f32 + %b = arith.constant 2.5 : f32 + %res = "neura.fmin"(%a, %b) : (f32, f32) -> f32 + // CHECK: [neura-interpreter] → Output: 2.500000 + return %res : f32 +} + diff --git a/test/optimization/constant_folding/simple_loop.mlir b/test/optimization/constant_folding/simple_loop.mlir index 19aa172e..588fa494 100644 --- a/test/optimization/constant_folding/simple_loop.mlir +++ b/test/optimization/constant_folding/simple_loop.mlir @@ -1,4 +1,5 @@ // RUN: mlir-neura-opt %s \ +// RUN: --promote-func-arg-to-const \ // RUN: --fold-constant \ // RUN: | FileCheck %s -check-prefix=FOLD @@ -11,36 +12,51 @@ module { %4 = "neura.constant"() <{value = 1 : i32}> : () -> i32 %5 = "neura.constant"() <{value = 2 : i32}> : () -> i32 %6 = "neura.constant"() <{value = 0 : i64}> : () -> i64 + %7 = "neura.constant"() <{value = 2.5 : f32}> : () -> f32 + %8 = "neura.constant"() <{value = 1.0 : f32}> : () -> f32 + %9 = "neura.constant"() <{value = 0.0 : f32}> : () -> f32 + %10 = "neura.constant"() <{value = 10.0 : f32}> : () -> f32 neura.br %6 : i64 to ^bb1 - ^bb1(%7: i64): // 2 preds: ^bb0, ^bb2 - %8 = "neura.icmp"(%7, %3) <{cmpType = "slt"}> : (i64, i64) -> i1 - neura.cond_br %8 : i1 then to ^bb2 else to ^bb3 + ^bb1(%11: i64): // 2 preds: ^bb0, ^bb2 + %12 = "neura.icmp"(%11, %3) <{cmpType = "slt"}> : (i64, i64) -> i1 + neura.cond_br %12 : i1 then to ^bb2 else to ^bb3 ^bb2: // pred: ^bb1 - %9 = neura.load_indexed %0[%7 : i64] memref : i32 - %10 = "neura.mul"(%5, %9) : (i32, i32) -> i32 - %11 = "neura.add"(%4, %9) : (i32, i32) -> i32 - neura.store_indexed %11 to %1[%7 : i64] memref : i32 - %12 = "neura.add"(%7, %2) : (i64, i64) -> i64 - neura.br %12 : i64 to ^bb1 + %13 = neura.load_indexed %0[%11 : i64] memref : i32 + %14 = "neura.mul"(%5, %13) : (i32, i32) -> i32 + %15 = "neura.add"(%4, %13) : (i32, i32) -> i32 + neura.store_indexed 
%15 to %1[%11 : i64] memref : i32 + + // Test new float operations with constant folding + %16 = "neura.cast"(%13) <{cast_type = "sitofp"}> : (i32) -> f32 + %17 = "neura.fmul"(%16, %7) : (f32, f32) -> f32 + %18 = "neura.fsub"(%17, %8) : (f32, f32) -> f32 + %19 = "neura.fmax"(%18, %9) : (f32, f32) -> f32 + %20 = "neura.fmin"(%19, %10) : (f32, f32) -> f32 + + %21 = "neura.add"(%11, %2) : (i64, i64) -> i64 + neura.br %21 : i64 to ^bb1 ^bb3: // pred: ^bb1 "neura.return"() : () -> () } } // FOLD: func.func @_Z11simple_loopPiS_(%arg0: memref, %arg1: memref) attributes {accelerator = "neura", llvm.linkage = #llvm.linkage} { -// FOLD-NEXT: %0 = "neura.constant"() <{value = "%arg0"}> : () -> memref -// FOLD-NEXT: %1 = "neura.constant"() <{value = "%arg1"}> : () -> memref -// FOLD-NEXT: %2 = "neura.constant"() <{value = 0 : i64}> : () -> i64 -// FOLD-NEXT: neura.br %2 : i64 to ^bb1 -// FOLD-NEXT: ^bb1(%3: i64): // 2 preds: ^bb0, ^bb2 -// FOLD-NEXT: %4 = "neura.icmp"(%3) <{cmpType = "slt"}> {rhs_value = 128 : i64} : (i64) -> i1 -// FOLD-NEXT: neura.cond_br %4 : i1 then to ^bb2 else to ^bb3 +// FOLD-NEXT: %0 = "neura.constant"() <{value = 0 : i64}> : () -> i64 +// FOLD-NEXT: neura.br %0 : i64 to ^bb1 +// FOLD-NEXT: ^bb1(%1: i64): // 2 preds: ^bb0, ^bb2 +// FOLD-NEXT: %2 = "neura.icmp"(%1) <{cmpType = "slt"}> {rhs_value = 128 : i64} : (i64) -> i1 +// FOLD-NEXT: neura.cond_br %2 : i1 then to ^bb2 else to ^bb3 // FOLD-NEXT: ^bb2: // pred: ^bb1 -// FOLD-NEXT: %5 = neura.load_indexed %0[%3 : i64] memref : i32 -// FOLD-NEXT: %6 = "neura.mul"(%5) {rhs_value = 2 : i32} : (i32) -> i32 -// FOLD-NEXT: %7 = "neura.add"(%5) {rhs_value = 1 : i32} : (i32) -> i32 -// FOLD-NEXT: neura.store_indexed %7 to %1[%3 : i64] memref : i32 -// FOLD-NEXT: %8 = "neura.add"(%3) {rhs_value = 1 : i64} : (i64) -> i64 -// FOLD-NEXT: neura.br %8 : i64 to ^bb1 +// FOLD-NEXT: %3 = neura.load_indexed [%1 : i64] {lhs_value = "%arg0"} : i32 +// FOLD-NEXT: %4 = "neura.mul"(%3) {rhs_value = 2 : i32} : (i32) -> i32 +// 
FOLD-NEXT: %5 = "neura.add"(%3) {rhs_value = 1 : i32} : (i32) -> i32 +// FOLD-NEXT: neura.store_indexed %5 to [%1 : i64] {rhs_value = "%arg1"} : i32 +// FOLD-NEXT: %6 = "neura.cast"(%3) <{cast_type = "sitofp"}> : (i32) -> f32 +// FOLD-NEXT: %7 = "neura.fmul"(%6) {rhs_value = 2.500000e+00 : f32} : (f32) -> f32 +// FOLD-NEXT: %8 = "neura.fsub"(%7) {rhs_value = 1.000000e+00 : f32} : (f32) -> f32 +// FOLD-NEXT: %9 = "neura.fmax"(%8) {rhs_value = 0.000000e+00 : f32} : (f32) -> f32 +// FOLD-NEXT: %10 = "neura.fmin"(%9) {rhs_value = 1.000000e+01 : f32} : (f32) -> f32 +// FOLD-NEXT: %11 = "neura.add"(%1) {rhs_value = 1 : i64} : (i64) -> i64 +// FOLD-NEXT: neura.br %11 : i64 to ^bb1 // FOLD-NEXT: ^bb3: // pred: ^bb1 // FOLD-NEXT: "neura.return"() : () -> () diff --git a/tools/neura-interpreter/neura-interpreter.cpp b/tools/neura-interpreter/neura-interpreter.cpp index 96405e0f..af2b2d61 100644 --- a/tools/neura-interpreter/neura-interpreter.cpp +++ b/tools/neura-interpreter/neura-interpreter.cpp @@ -910,6 +910,154 @@ bool handleFDivOp( return true; } +/** + * @brief Handles the execution of a Neura floating-point maximum operation + * (neura.fmax) by computing the maximum of floating-point operands. + * + * This function processes Neura's floating-point maximum operations, which + * take 2-3 operands: two floating-point inputs (LHS and RHS) and an optional + * predicate operand. It calculates the maximum of the floating-point values + * (max(LHS, RHS)), combines the predicates of all operands (including the + * optional predicate if present), and stores the result in the value map. The + * operation requires at least two operands; fewer will result in an error. 
+ *
+ * @param op The neura.fmax operation to handle
+ * @param value_to_predicated_data_map Reference to the map where the result
+ *                                     will be stored, keyed by the
+ *                                     operation's result value
+ * @return bool True if the floating-point maximum
+ *         is successfully computed; false if there are fewer than 2 operands
+ */
+bool handleFMaxOp(
+    neura::FMaxOp op,
+    llvm::DenseMap<Value, PredicatedData> &value_to_predicated_data_map) {
+  if (isVerboseMode()) {
+    llvm::outs() << "[neura-interpreter] Executing neura.fmax:\n";
+  }
+
+  if (op.getNumOperands() < 2) {
+    if (isVerboseMode()) {
+      llvm::errs() << "[neura-interpreter] └─ neura.fmax expects at least two "
+                      "operands\n";
+    }
+    return false;
+  }
+
+  auto lhs = value_to_predicated_data_map[op.getOperand(0)];
+  auto rhs = value_to_predicated_data_map[op.getOperand(1)];
+
+  if (isVerboseMode()) {
+    llvm::outs() << "[neura-interpreter] ├─ Operands \n";
+    llvm::outs() << "[neura-interpreter] │ ├─ LHS : value = " << lhs.value
+                 << " [pred = " << lhs.predicate << "]\n";
+    llvm::outs() << "[neura-interpreter] │ └─ RHS : value = " << rhs.value
+                 << " [pred = " << rhs.predicate << "]\n";
+  }
+
+  bool final_predicate = lhs.predicate && rhs.predicate;
+
+  if (op.getNumOperands() > 2) {
+    auto pred = value_to_predicated_data_map[op.getOperand(2)];
+    final_predicate = final_predicate && pred.predicate && (pred.value != 0.0f);
+    if (isVerboseMode()) {
+      llvm::outs() << "[neura-interpreter] ├─ Execution Context\n";
+      llvm::outs() << "[neura-interpreter] │ └─ Pred : value = " << pred.value
+                   << " [pred = " << pred.predicate << "]\n";
+    }
+  }
+
+  float lhs_float = static_cast<float>(lhs.value);
+  float rhs_float = static_cast<float>(rhs.value);
+  float result_float = std::max(lhs_float, rhs_float);
+
+  PredicatedData result;
+  result.value = result_float;
+  result.predicate = final_predicate;
+  result.is_vector = false;
+
+  if (isVerboseMode()) {
+    llvm::outs() << "[neura-interpreter] └─ Result : value = " << result.value
+                 << " [pred = " << result.predicate << "]\n";
+  }
+
+  value_to_predicated_data_map[op.getResult()] = result;
+  return true;
+}
+
+/**
+ * @brief Handles the execution of a Neura floating-point minimum operation
+ * (neura.fmin) by computing the minimum of floating-point operands.
+ *
+ * This function processes Neura's floating-point minimum operations, which
+ * take 2-3 operands: two floating-point inputs (LHS and RHS) and an optional
+ * predicate operand. It calculates the minimum of the floating-point values
+ * (min(LHS, RHS)), combines the predicates of all operands (including the
+ * optional predicate if present), and stores the result in the value map. The
+ * operation requires at least two operands; fewer will result in an error.
+ *
+ * @param op The neura.fmin operation to handle
+ * @param value_to_predicated_data_map Reference to the map where the result
+ *                                     will be stored, keyed by the
+ *                                     operation's result value
+ * @return bool True if the floating-point minimum
+ *         is successfully computed; false if there are fewer than 2 operands
+ */
+bool handleFMinOp(
+    neura::FMinOp op,
+    llvm::DenseMap<Value, PredicatedData> &value_to_predicated_data_map) {
+  if (isVerboseMode()) {
+    llvm::outs() << "[neura-interpreter] Executing neura.fmin:\n";
+  }
+
+  if (op.getNumOperands() < 2) {
+    if (isVerboseMode()) {
+      llvm::errs() << "[neura-interpreter] └─ neura.fmin expects at least two "
+                      "operands\n";
+    }
+    return false;
+  }
+
+  auto lhs = value_to_predicated_data_map[op.getOperand(0)];
+  auto rhs = value_to_predicated_data_map[op.getOperand(1)];
+
+  if (isVerboseMode()) {
+    llvm::outs() << "[neura-interpreter] ├─ Operands \n";
+    llvm::outs() << "[neura-interpreter] │ ├─ LHS : value = " << lhs.value
+                 << " [pred = " << lhs.predicate << "]\n";
+    llvm::outs() << "[neura-interpreter] │ └─ RHS : value = " << rhs.value
+                 << " [pred = " << rhs.predicate << "]\n";
+  }
+
+  bool final_predicate = lhs.predicate && rhs.predicate;
+
+  if (op.getNumOperands() > 2) {
+    auto pred = value_to_predicated_data_map[op.getOperand(2)];
+    final_predicate = final_predicate && pred.predicate && (pred.value != 0.0f);
+    if (isVerboseMode()) {
+      llvm::outs() << "[neura-interpreter] ├─ Execution Context\n";
+      llvm::outs() << "[neura-interpreter] │ └─ Pred : value = " << pred.value
+                   << " [pred = " << pred.predicate << "]\n";
+    }
+  }
+
+  float lhs_float = static_cast<float>(lhs.value);
+  float rhs_float = static_cast<float>(rhs.value);
+  float result_float = std::min(lhs_float, rhs_float);
+
+  PredicatedData result;
+  result.value = result_float;
+  result.predicate = final_predicate;
+  result.is_vector = false;
+
+  if (isVerboseMode()) {
+    llvm::outs() << "[neura-interpreter] └─ Result : value = " << result.value
+                 << " [pred = " << result.predicate << "]\n";
+  }
+
+  value_to_predicated_data_map[op.getResult()] = result;
+  return true;
+}
+
 /**
  * @brief Handles the execution of a Neura vector floating-point
  * multiplication operation (neura.vfmul) by computing element-wise products
@@ -3131,6 +3279,10 @@ OperationHandleResult handleOperation(
     result.success = handleFMulOp(fmul_op, value_to_predicated_data_map);
   } else if (auto fdiv_op = dyn_cast<neura::FDivOp>(op)) {
     result.success = handleFDivOp(fdiv_op, value_to_predicated_data_map);
+  } else if (auto fmax_op = dyn_cast<neura::FMaxOp>(op)) {
+    result.success = handleFMaxOp(fmax_op, value_to_predicated_data_map);
+  } else if (auto fmin_op = dyn_cast<neura::FMinOp>(op)) {
+    result.success = handleFMinOp(fmin_op, value_to_predicated_data_map);
   } else if (auto vfmul_op = dyn_cast<neura::VFMulOp>(op)) {
     result.success = handleVFMulOp(vfmul_op, value_to_predicated_data_map);
   } else if (auto fadd_fadd_op = dyn_cast<neura::FAddFAddOp>(op)) {

From 6dea9d68ea8f40c24f768692baa3efc0b17dc8da Mon Sep 17 00:00:00 2001
From: tangyz <739245980@qq.com>
Date: Tue, 21 Oct 2025 13:37:23 +0800
Subject: [PATCH 3/4] Add NaN semantic support for FMax/FMin operations with
 maxnum/maximum and minnum/minimum

---
 include/NeuraDialect/NeuraOps.td              | 22 +++++---
 .../LlvmToNeura/LlvmToNeuraPass.cpp           | 46 ++++++++++++++++-
 .../HwAgnosticOpt/FoldConstantPass.cpp        |  4 +-
.../interpreter/basic_operation/fmax.mlir | 50 ++++++++++++++++--- .../interpreter/basic_operation/fmin.mlir | 50 ++++++++++++++++--- tools/neura-interpreter/neura-interpreter.cpp | 48 +++++++++++++++++- 6 files changed, 196 insertions(+), 24 deletions(-) diff --git a/include/NeuraDialect/NeuraOps.td b/include/NeuraDialect/NeuraOps.td index cedf3d09..be6f12c7 100644 --- a/include/NeuraDialect/NeuraOps.td +++ b/include/NeuraDialect/NeuraOps.td @@ -109,15 +109,20 @@ def Neura_FMaxOp : Op { let summary = "Floating-point maximum operation"; let description = [{ Returns the maximum of two floating-point values. - Corresponds to llvm.maxnum intrinsic. + Supports two NaN propagation semantics: + - "maxnum": Returns non-NaN value when one operand is NaN (llvm.maxnum) + - "maximum": Propagates NaN when any operand is NaN (llvm.maximum) Example: - %result = neura.fmax %a, %b : f32 + %result = neura.fmax<"maxnum">(%a, %b : f32) : f32 -> f32 + %result = neura.fmax<"maximum">(%a, %b : f32) : f32 -> f32 }]; let opName = "fmax"; - let arguments = (ins AnyType:$lhs, Optional:$rhs); + let arguments = (ins AnyType:$lhs, Optional:$rhs, + DefaultValuedAttr:$nan_semantic); let results = (outs AnyType:$result); let traits = [SameOperandsAndResultElementType]; + let assemblyFormat = "`<` $nan_semantic `>` `(` $lhs (`,` $rhs^ `:` type($rhs))? `)` attr-dict `:` type($lhs) `->` type($result)"; } // Defines a floating-point minimum operation. @@ -125,15 +130,20 @@ def Neura_FMinOp : Op { let summary = "Floating-point minimum operation"; let description = [{ Returns the minimum of two floating-point values. - Corresponds to llvm.minnum intrinsic. 
+ Supports two NaN propagation semantics: + - "minnum": Returns non-NaN value when one operand is NaN (llvm.minnum) + - "minimum": Propagates NaN when any operand is NaN (llvm.minimum) Example: - %result = neura.fmin %a, %b : f32 + %result = neura.fmin<"minnum">(%a, %b : f32) : f32 -> f32 + %result = neura.fmin<"minimum">(%a, %b : f32) : f32 -> f32 }]; let opName = "fmin"; - let arguments = (ins AnyType:$lhs, Optional:$rhs); + let arguments = (ins AnyType:$lhs, Optional:$rhs, + DefaultValuedAttr:$nan_semantic); let results = (outs AnyType:$result); let traits = [SameOperandsAndResultElementType]; + let assemblyFormat = "`<` $nan_semantic `>` `(` $lhs (`,` $rhs^ `:` type($rhs))? `)` attr-dict `:` type($lhs) `->` type($result)"; } // Defines a bitwise OR operation. diff --git a/lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp b/lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp index b24523b0..2151dfaf 100644 --- a/lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp +++ b/lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp @@ -147,7 +147,27 @@ struct LlvmMaxNumToNeuraFMax : public OpRewritePattern { if (!mlir::isa(resultType)) return failure(); - rewriter.replaceOpWithNewOp(op, resultType, lhs, rhs); + rewriter.replaceOpWithNewOp(op, resultType, lhs, rhs, + rewriter.getStringAttr("maxnum")); + return success(); + } +}; + +struct LlvmMaximumToNeuraFMax : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LLVM::MaximumOp op, + PatternRewriter &rewriter) const override { + Value lhs = op->getOperand(0); + Value rhs = op->getOperand(1); + Type resultType = op->getResult(0).getType(); + + // Only matches scalar float. 
+    if (!mlir::isa<FloatType>(resultType))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<neura::FMaxOp>(op, resultType, lhs, rhs,
+                                               rewriter.getStringAttr("maximum"));
     return success();
   }
 };
@@ -165,7 +185,27 @@ struct LlvmMinNumToNeuraFMin : public OpRewritePattern<LLVM::MinNumOp> {
     if (!mlir::isa<FloatType>(resultType))
       return failure();
 
-    rewriter.replaceOpWithNewOp<neura::FMinOp>(op, resultType, lhs, rhs);
+    rewriter.replaceOpWithNewOp<neura::FMinOp>(op, resultType, lhs, rhs,
+                                               rewriter.getStringAttr("minnum"));
+    return success();
+  }
+};
+
+struct LlvmMinimumToNeuraFMin : public OpRewritePattern<LLVM::MinimumOp> {
+  using OpRewritePattern<LLVM::MinimumOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(LLVM::MinimumOp op,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = op->getOperand(0);
+    Value rhs = op->getOperand(1);
+    Type resultType = op->getResult(0).getType();
+
+    // Only matches scalar float.
+    if (!mlir::isa<FloatType>(resultType))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<neura::FMinOp>(op, resultType, lhs, rhs,
+                                               rewriter.getStringAttr("minimum"));
     return success();
   }
 };
@@ -622,7 +662,9 @@ struct LowerLlvmToNeuraPass
     patterns.add(&getContext());
     patterns.add(&getContext());
     patterns.add<LlvmMaxNumToNeuraFMax>(&getContext());
+    patterns.add<LlvmMaximumToNeuraFMax>(&getContext());
     patterns.add<LlvmMinNumToNeuraFMin>(&getContext());
+    patterns.add<LlvmMinimumToNeuraFMin>(&getContext());
     patterns.add(&getContext());
     patterns.add(&getContext());
     patterns.add(&getContext());
diff --git a/lib/NeuraDialect/Transforms/Optimizations/HwAgnosticOpt/FoldConstantPass.cpp b/lib/NeuraDialect/Transforms/Optimizations/HwAgnosticOpt/FoldConstantPass.cpp
index 86a86264..62787bac 100644
--- a/lib/NeuraDialect/Transforms/Optimizations/HwAgnosticOpt/FoldConstantPass.cpp
+++ b/lib/NeuraDialect/Transforms/Optimizations/HwAgnosticOpt/FoldConstantPass.cpp
@@ -285,7 +285,7 @@ struct FuseFMaxRhsConstantPattern
                        PatternRewriter &rewriter) const override {
     auto fused_op = rewriter.create<neura::FMaxOp>(
         op.getLoc(), op.getResult().getType(), non_const_operand,
-        /*rhs=*/nullptr);
+        /*rhs=*/nullptr, op.getNanSemantic());
     addConstantAttribute(fused_op, "rhs_value", rhs_value);
     return fused_op;
   }
@@ -303,7 +303,7 @@ struct FuseFMinRhsConstantPattern PatternRewriter &rewriter) const override { auto fused_op = rewriter.create( op.getLoc(), op.getResult().getType(), non_const_operand, - /*rhs=*/nullptr); + /*rhs=*/nullptr, op.getNanSemantic()); addConstantAttribute(fused_op, "rhs_value", rhs_value); return fused_op; } diff --git a/test/neura/interpreter/basic_operation/fmax.mlir b/test/neura/interpreter/basic_operation/fmax.mlir index 9e86374b..26d0aacb 100644 --- a/test/neura/interpreter/basic_operation/fmax.mlir +++ b/test/neura/interpreter/basic_operation/fmax.mlir @@ -6,7 +6,7 @@ func.func @test_fmax_positive() -> f32 { %a = arith.constant 10.0 : f32 %b = arith.constant 32.0 : f32 - %res = "neura.fmax"(%a, %b) : (f32, f32) -> f32 + %res = neura.fmax<"maxnum">(%a, %b : f32) : f32 -> f32 // CHECK: [neura-interpreter] → Output: 32.000000 return %res : f32 } @@ -17,7 +17,7 @@ func.func @test_fmax_positive() -> f32 { func.func @test_fmax_mixed() -> f32 { %a = arith.constant -5.0 : f32 %b = arith.constant 3.0 : f32 - %res = "neura.fmax"(%a, %b) : (f32, f32) -> f32 + %res = neura.fmax<"maxnum">(%a, %b : f32) : f32 -> f32 // CHECK: [neura-interpreter] → Output: 3.000000 return %res : f32 } @@ -28,7 +28,7 @@ func.func @test_fmax_mixed() -> f32 { func.func @test_fmax_negative() -> f32 { %a = arith.constant -10.0 : f32 %b = arith.constant -3.0 : f32 - %res = "neura.fmax"(%a, %b) : (f32, f32) -> f32 + %res = neura.fmax<"maxnum">(%a, %b : f32) : f32 -> f32 // CHECK: [neura-interpreter] → Output: -3.000000 return %res : f32 } @@ -39,7 +39,7 @@ func.func @test_fmax_negative() -> f32 { func.func @test_fmax_zero() -> f32 { %a = arith.constant 0.0 : f32 %b = arith.constant -7.0 : f32 - %res = "neura.fmax"(%a, %b) : (f32, f32) -> f32 + %res = neura.fmax<"maxnum">(%a, %b : f32) : f32 -> f32 // CHECK: [neura-interpreter] → Output: 0.000000 return %res : f32 } @@ -50,7 +50,7 @@ func.func @test_fmax_zero() -> f32 { func.func @test_fmax_equal() -> f32 { %a = arith.constant 5.5 : 
f32 %b = arith.constant 5.5 : f32 - %res = "neura.fmax"(%a, %b) : (f32, f32) -> f32 + %res = neura.fmax<"maxnum">(%a, %b : f32) : f32 -> f32 // CHECK: [neura-interpreter] → Output: 5.500000 return %res : f32 } @@ -61,8 +61,46 @@ func.func @test_fmax_equal() -> f32 { func.func @test_fmax_fraction() -> f32 { %a = arith.constant 2.75 : f32 %b = arith.constant 2.5 : f32 - %res = "neura.fmax"(%a, %b) : (f32, f32) -> f32 + %res = neura.fmax<"maxnum">(%a, %b : f32) : f32 -> f32 // CHECK: [neura-interpreter] → Output: 2.750000 return %res : f32 } +// ===----------------------------------------------------------------------===// +// Test 7: FMax with NaN (maxnum semantic) +// ===----------------------------------------------------------------------===// +func.func @test_fmax_nan_maxnum_lhs() -> f32 { + %nan = arith.constant 0x7FC00000 : f32 // NaN + %b = arith.constant 5.0 : f32 + %res = neura.fmax<"maxnum">(%nan, %b : f32) : f32 -> f32 + // CHECK: [neura-interpreter] → Output: 5.000000 + return %res : f32 +} + +func.func @test_fmax_nan_maxnum_rhs() -> f32 { + %a = arith.constant 5.0 : f32 + %nan = arith.constant 0x7FC00000 : f32 // NaN + %res = neura.fmax<"maxnum">(%a, %nan : f32) : f32 -> f32 + // CHECK: [neura-interpreter] → Output: 5.000000 + return %res : f32 +} + +// ===----------------------------------------------------------------------===// +// Test 8: FMax with NaN (maximum semantic) +// ===----------------------------------------------------------------------===// +func.func @test_fmax_nan_maximum_lhs() -> f32 { + %nan = arith.constant 0x7FC00000 : f32 // NaN + %b = arith.constant 5.0 : f32 + %res = neura.fmax<"maximum">(%nan, %b : f32) : f32 -> f32 + // CHECK: [neura-interpreter] → Output: nan + return %res : f32 +} + +func.func @test_fmax_nan_maximum_rhs() -> f32 { + %a = arith.constant 5.0 : f32 + %nan = arith.constant 0x7FC00000 : f32 // NaN + %res = neura.fmax<"maximum">(%a, %nan : f32) : f32 -> f32 + // CHECK: [neura-interpreter] → Output: nan + return 
%res : f32 +} + diff --git a/test/neura/interpreter/basic_operation/fmin.mlir b/test/neura/interpreter/basic_operation/fmin.mlir index afba4bb9..601d0b8a 100644 --- a/test/neura/interpreter/basic_operation/fmin.mlir +++ b/test/neura/interpreter/basic_operation/fmin.mlir @@ -6,7 +6,7 @@ func.func @test_fmin_positive() -> f32 { %a = arith.constant 10.0 : f32 %b = arith.constant 32.0 : f32 - %res = "neura.fmin"(%a, %b) : (f32, f32) -> f32 + %res = neura.fmin<"minnum">(%a, %b : f32) : f32 -> f32 // CHECK: [neura-interpreter] → Output: 10.000000 return %res : f32 } @@ -17,7 +17,7 @@ func.func @test_fmin_positive() -> f32 { func.func @test_fmin_mixed() -> f32 { %a = arith.constant -5.0 : f32 %b = arith.constant 3.0 : f32 - %res = "neura.fmin"(%a, %b) : (f32, f32) -> f32 + %res = neura.fmin<"minnum">(%a, %b : f32) : f32 -> f32 // CHECK: [neura-interpreter] → Output: -5.000000 return %res : f32 } @@ -28,7 +28,7 @@ func.func @test_fmin_mixed() -> f32 { func.func @test_fmin_negative() -> f32 { %a = arith.constant -10.0 : f32 %b = arith.constant -3.0 : f32 - %res = "neura.fmin"(%a, %b) : (f32, f32) -> f32 + %res = neura.fmin<"minnum">(%a, %b : f32) : f32 -> f32 // CHECK: [neura-interpreter] → Output: -10.000000 return %res : f32 } @@ -39,7 +39,7 @@ func.func @test_fmin_negative() -> f32 { func.func @test_fmin_zero() -> f32 { %a = arith.constant 0.0 : f32 %b = arith.constant 7.0 : f32 - %res = "neura.fmin"(%a, %b) : (f32, f32) -> f32 + %res = neura.fmin<"minnum">(%a, %b : f32) : f32 -> f32 // CHECK: [neura-interpreter] → Output: 0.000000 return %res : f32 } @@ -50,7 +50,7 @@ func.func @test_fmin_zero() -> f32 { func.func @test_fmin_equal() -> f32 { %a = arith.constant 5.5 : f32 %b = arith.constant 5.5 : f32 - %res = "neura.fmin"(%a, %b) : (f32, f32) -> f32 + %res = neura.fmin<"minnum">(%a, %b : f32) : f32 -> f32 // CHECK: [neura-interpreter] → Output: 5.500000 return %res : f32 } @@ -61,8 +61,46 @@ func.func @test_fmin_equal() -> f32 { func.func @test_fmin_fraction() -> f32 { 
%a = arith.constant 2.75 : f32 %b = arith.constant 2.5 : f32 - %res = "neura.fmin"(%a, %b) : (f32, f32) -> f32 + %res = neura.fmin<"minnum">(%a, %b : f32) : f32 -> f32 // CHECK: [neura-interpreter] → Output: 2.500000 return %res : f32 } +// ===----------------------------------------------------------------------===// +// Test 7: FMin with NaN (minnum semantic) +// ===----------------------------------------------------------------------===// +func.func @test_fmin_nan_minnum_lhs() -> f32 { + %nan = arith.constant 0x7FC00000 : f32 // NaN + %b = arith.constant 5.0 : f32 + %res = neura.fmin<"minnum">(%nan, %b : f32) : f32 -> f32 + // CHECK: [neura-interpreter] → Output: 5.000000 + return %res : f32 +} + +func.func @test_fmin_nan_minnum_rhs() -> f32 { + %a = arith.constant 5.0 : f32 + %nan = arith.constant 0x7FC00000 : f32 // NaN + %res = neura.fmin<"minnum">(%a, %nan : f32) : f32 -> f32 + // CHECK: [neura-interpreter] → Output: 5.000000 + return %res : f32 +} + +// ===----------------------------------------------------------------------===// +// Test 8: FMin with NaN (minimum semantic) +// ===----------------------------------------------------------------------===// +func.func @test_fmin_nan_minimum_lhs() -> f32 { + %nan = arith.constant 0x7FC00000 : f32 // NaN + %b = arith.constant 5.0 : f32 + %res = neura.fmin<"minimum">(%nan, %b : f32) : f32 -> f32 + // CHECK: [neura-interpreter] → Output: nan + return %res : f32 +} + +func.func @test_fmin_nan_minimum_rhs() -> f32 { + %a = arith.constant 5.0 : f32 + %nan = arith.constant 0x7FC00000 : f32 // NaN + %res = neura.fmin<"minimum">(%a, %nan : f32) : f32 -> f32 + // CHECK: [neura-interpreter] → Output: nan + return %res : f32 +} + diff --git a/tools/neura-interpreter/neura-interpreter.cpp b/tools/neura-interpreter/neura-interpreter.cpp index af2b2d61..705f97b0 100644 --- a/tools/neura-interpreter/neura-interpreter.cpp +++ b/tools/neura-interpreter/neura-interpreter.cpp @@ -968,7 +968,28 @@ bool handleFMaxOp( float 
lhs_float = static_cast(lhs.value); float rhs_float = static_cast(rhs.value); - float result_float = std::max(lhs_float, rhs_float); + + // Get NaN semantic attribute (default is "maxnum") + std::string nan_semantic = op.getNanSemantic().str(); + float result_float; + + if (nan_semantic == "maxnum") { + // maxnum semantic: return non-NaN value when one operand is NaN + if (std::isnan(lhs_float) && !std::isnan(rhs_float)) { + result_float = rhs_float; + } else if (std::isnan(rhs_float) && !std::isnan(lhs_float)) { + result_float = lhs_float; + } else { + result_float = std::max(lhs_float, rhs_float); + } + } else { // "maximum" + // maximum semantic: propagate NaN when any operand is NaN + if (std::isnan(lhs_float) || std::isnan(rhs_float)) { + result_float = std::nan(""); + } else { + result_float = std::max(lhs_float, rhs_float); + } + } PredicatedData result; result.value = result_float; @@ -976,6 +997,7 @@ bool handleFMaxOp( result.is_vector = false; if (isVerboseMode()) { + llvm::outs() << "[neura-interpreter] ├─ NaN semantic: " << nan_semantic << "\n"; llvm::outs() << "[neura-interpreter] └─ Result : value = " << result.value << " [pred = " << result.predicate << "]\n"; } @@ -1042,7 +1064,28 @@ bool handleFMinOp( float lhs_float = static_cast(lhs.value); float rhs_float = static_cast(rhs.value); - float result_float = std::min(lhs_float, rhs_float); + + // Get NaN semantic attribute (default is "minnum") + std::string nan_semantic = op.getNanSemantic().str(); + float result_float; + + if (nan_semantic == "minnum") { + // minnum semantic: return non-NaN value when one operand is NaN + if (std::isnan(lhs_float) && !std::isnan(rhs_float)) { + result_float = rhs_float; + } else if (std::isnan(rhs_float) && !std::isnan(lhs_float)) { + result_float = lhs_float; + } else { + result_float = std::min(lhs_float, rhs_float); + } + } else { // "minimum" + // minimum semantic: propagate NaN when any operand is NaN + if (std::isnan(lhs_float) || std::isnan(rhs_float)) { + 
result_float = std::nan(""); + } else { + result_float = std::min(lhs_float, rhs_float); + } + } PredicatedData result; result.value = result_float; @@ -1050,6 +1093,7 @@ bool handleFMinOp( result.is_vector = false; if (isVerboseMode()) { + llvm::outs() << "[neura-interpreter] ├─ NaN semantic: " << nan_semantic << "\n"; llvm::outs() << "[neura-interpreter] └─ Result : value = " << result.value << " [pred = " << result.predicate << "]\n"; } From 9e98c66278252b494ebc8177aa596b065c871755 Mon Sep 17 00:00:00 2001 From: tangyz <739245980@qq.com> Date: Tue, 21 Oct 2025 13:45:36 +0800 Subject: [PATCH 4/4] Update constant folding test for new FMax/FMin syntax with NaN semantics --- test/optimization/constant_folding/simple_loop.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/optimization/constant_folding/simple_loop.mlir b/test/optimization/constant_folding/simple_loop.mlir index 588fa494..5859b4f2 100644 --- a/test/optimization/constant_folding/simple_loop.mlir +++ b/test/optimization/constant_folding/simple_loop.mlir @@ -54,8 +54,8 @@ module { // FOLD-NEXT: %6 = "neura.cast"(%3) <{cast_type = "sitofp"}> : (i32) -> f32 // FOLD-NEXT: %7 = "neura.fmul"(%6) {rhs_value = 2.500000e+00 : f32} : (f32) -> f32 // FOLD-NEXT: %8 = "neura.fsub"(%7) {rhs_value = 1.000000e+00 : f32} : (f32) -> f32 -// FOLD-NEXT: %9 = "neura.fmax"(%8) {rhs_value = 0.000000e+00 : f32} : (f32) -> f32 -// FOLD-NEXT: %10 = "neura.fmin"(%9) {rhs_value = 1.000000e+01 : f32} : (f32) -> f32 +// FOLD-NEXT: %9 = neura.fmax<"maxnum"> (%8) {rhs_value = 0.000000e+00 : f32} : f32 -> f32 +// FOLD-NEXT: %10 = neura.fmin<"minnum"> (%9) {rhs_value = 1.000000e+01 : f32} : f32 -> f32 // FOLD-NEXT: %11 = "neura.add"(%1) {rhs_value = 1 : i64} : (i64) -> i64 // FOLD-NEXT: neura.br %11 : i64 to ^bb1 // FOLD-NEXT: ^bb3: // pred: ^bb1