From 1a0b379c2a797ff7fb85e1483ebcd9cb7213a67b Mon Sep 17 00:00:00 2001 From: Alex Light Date: Tue, 21 Apr 2026 10:19:39 -0700 Subject: [PATCH] Fix Z3 translation of ArraySlice and add fuzzing infrastructure. The Z3 translation for ArraySlice is updated to correctly handle index addition by zero-extending the start index to prevent overflow. New tests are added to verify the equivalence of ArraySlice and ArrayIndex operations against alternative IR structures using both exhaustive simulation and Z3 proving. A fuzzing framework is introduced, which generates random IR, exhaustively interprets it to find all possible return values, and then builds a Z3 query to check if the fuzzed function's return value is always within this set of possible values. The fuzzer should hopefully (very slowly) prove that we do not have any similar translation bugs. PiperOrigin-RevId: 903299223 --- xls/solvers/BUILD | 24 ++ .../z3_ir_equivalence_testutils_test.cc | 51 +++ xls/solvers/z3_ir_translator.cc | 18 +- xls/solvers/z3_ir_translator_test.cc | 307 ++++++++++++++++++ 4 files changed, 395 insertions(+), 5 deletions(-) diff --git a/xls/solvers/BUILD b/xls/solvers/BUILD index 1c31f68e68..155c6aefb9 100644 --- a/xls/solvers/BUILD +++ b/xls/solvers/BUILD @@ -158,10 +158,17 @@ cc_test( ":z3_ir_equivalence_testutils", "//xls/common:xls_gunit_main", "//xls/common/status:matchers", + "//xls/common/status:ret_check", + "//xls/interpreter:ir_interpreter", "//xls/ir", + "//xls/ir:bits", + "//xls/ir:events", "//xls/ir:function_builder", "//xls/ir:ir_test_base", "//xls/ir:value", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", "@googletest//:gtest", ], ) @@ -211,11 +218,18 @@ cc_test( ":z3_ir_translator_matchers", ":z3_utils", "//xls/common:xls_gunit_main", + "//xls/common/fuzzing:fuzztest", "//xls/common/status:matchers", + "//xls/common/status:ret_check", + "//xls/common/status:status_macros", + "//xls/data_structures:inline_bitmap", + "//xls/data_structures:leaf_type_tree", + "//xls/fuzzer/ir_fuzzer:ir_fuzz_domain", "//xls/interpreter:ir_interpreter", "//xls/ir", "//xls/ir:bits", "//xls/ir:bits_ops", + "//xls/ir:events", "//xls/ir:format_preference", "//xls/ir:function_builder", "//xls/ir:ir_parser", @@ -224,8 +238,18 @@ cc_test( "//xls/ir:source_location", "//xls/ir:value", "//xls/ir:value_builder", + "//xls/ir:value_flattening", + "//xls/ir:value_utils", + "//xls/passes:dfe_pass", + "//xls/passes:inlining_pass", + "//xls/passes:optimization_pass", + "//xls/passes:pass_base", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", diff --git a/xls/solvers/z3_ir_equivalence_testutils_test.cc b/xls/solvers/z3_ir_equivalence_testutils_test.cc index a9fd9ee5ca..4ad288eef8 100644 --- a/xls/solvers/z3_ir_equivalence_testutils_test.cc +++ b/xls/solvers/z3_ir_equivalence_testutils_test.cc @@ -14,8 +14,17 @@ #include "xls/solvers/z3_ir_equivalence_testutils.h" +#include + #include "gtest/gtest.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/types/span.h" #include "xls/common/status/matchers.h" +#include "xls/common/status/ret_check.h" +#include "xls/interpreter/ir_interpreter.h" +#include "xls/ir/bits.h" +#include "xls/ir/events.h" #include "xls/ir/function_builder.h" #include "xls/ir/ir_test_base.h" #include "xls/ir/nodes.h" @@ -42,5 +51,47 @@ TEST_F(Z3IrEquivalenceTestutilsTest, EquivWithAssert) { XLS_ASSERT_OK(f->RemoveNode(add.node())); } +class FunctionInterpreter : public IrInterpreter { + public: + using IrInterpreter::IrInterpreter; + absl::Status HandleParam(Param* param) override { + XLS_RET_CHECK(HasResult(param)) << param; + return absl::OkStatus(); + } +}; + +TEST_F(Z3IrEquivalenceTestutilsTest, EquivArraySlice) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetArrayType(2, p->GetBitsType(4))); + BValue y = fb.Param("y", p->GetBitsType(4)); + BValue eq_0 = fb.Eq(y, fb.Literal(UBits(0, 4))); + fb.ArrayIndex(x, {fb.Literal(UBits(0, 4))}); + BValue element_1 = fb.ArrayIndex(x, {fb.Literal(UBits(1, 4))}); + BValue end = fb.Array({element_1, element_1}, element_1.GetType()); + BValue transformed = fb.Select(eq_0, /*on_true=*/x, /*on_false=*/end); + BValue res = fb.ArraySlice(x, y, 2); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + ScopedVerifyEquivalence sve(f); + // Exhaustively check all inputs for equivalence. + for (uint64_t x0 = 0; x0 < 16; ++x0) { + for (uint64_t x1 = 0; x1 < 16; ++x1) { + for (uint64_t y0 = 0; y0 < 16; ++y0) { + InterpreterEvents events; + XLS_ASSERT_OK_AND_ASSIGN(Value x_val, Value::UBitsArray({x0, x1}, 4)); + absl::flat_hash_map node_values{ + {x.node(), x_val}, {y.node(), Value(UBits(y0, 4))}}; + FunctionInterpreter interp(&node_values, &events); + XLS_ASSERT_OK(f->Accept(&interp)); + EXPECT_EQ(node_values.at(res.node()), + node_values.at(transformed.node())) + << " @ [" << x0 << ", " << x1 << "], " << y; + } + } + } + XLS_ASSERT_OK( + res.node()->ReplaceImplicitUsesWith(transformed.node()).status()); +} + } // namespace } // namespace xls::solvers::z3 diff --git a/xls/solvers/z3_ir_translator.cc b/xls/solvers/z3_ir_translator.cc index 288ec012c4..dabe47bc85 100644 --- a/xls/solvers/z3_ir_translator.cc +++ b/xls/solvers/z3_ir_translator.cc @@ -926,13 +926,21 @@ absl::Status IrTranslator::HandleArraySlice(ArraySlice* array_slice) { Z3_ast start_ast = GetValue(array_slice->start()); ArrayType* input_type = array_slice->array()->GetType()->AsArrayOrDie(); ArrayType result_type(array_slice->width(), input_type->element_type()); - Z3_ast formatted_start_ast = - GetAsFormattedArrayIndex(ctx_, start_ast, input_type); - + int64_t min_offset_bits = Bits::MinBitCountUnsigned(array_slice->width()); + // Make sure we don't overflow. + int64_t offset_width = + 1 + std::max( + min_offset_bits, + Z3_get_bv_sort_size(ctx_, Z3_get_sort(ctx_, start_ast))); + Z3_sort index_sort = Z3_mk_bv_sort(ctx_, offset_width); + Z3_ast start_ext = Z3_mk_zero_ext( + ctx_, + offset_width - Z3_get_bv_sort_size(ctx_, Z3_get_sort(ctx_, start_ast)), + start_ast); std::vector elements; for (uint64_t i = 0; i < array_slice->width(); ++i) { - Z3_ast i_ast = GetAsFormattedArrayIndex(ctx_, i, input_type); - Z3_ast index_ast = Z3_mk_bvadd(ctx_, i_ast, formatted_start_ast); + Z3_ast i_ast = Z3_mk_int64(ctx_, i, index_sort); + Z3_ast index_ast = Z3_mk_bvadd(ctx_, start_ext, i_ast); elements.push_back(GetArrayElement(input_type, array_ast, index_ast)); } diff --git a/xls/solvers/z3_ir_translator_test.cc b/xls/solvers/z3_ir_translator_test.cc index 8b3b8e5202..7da319bc73 100644 --- a/xls/solvers/z3_ir_translator_test.cc +++ b/xls/solvers/z3_ir_translator_test.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -25,17 +26,29 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "xls/common/fuzzing/fuzztest.h" +#include "absl/algorithm/container.h" #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" +#include "absl/strings/str_cat.h" #include "absl/strings/substitute.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "xls/common/status/matchers.h" +#include "xls/common/status/ret_check.h" +#include "xls/common/status/status_macros.h" +#include "xls/data_structures/inline_bitmap.h" +#include "xls/data_structures/leaf_type_tree.h" +#include "xls/fuzzer/ir_fuzzer/ir_fuzz_domain.h" #include "xls/interpreter/ir_interpreter.h" #include "xls/ir/bits.h" #include "xls/ir/bits_ops.h" +#include "xls/ir/events.h" #include "xls/ir/format_preference.h" #include "xls/ir/function_builder.h" #include "xls/ir/ir_parser.h" @@ -47,6 +60,12 @@ #include "xls/ir/source_location.h" #include "xls/ir/value.h" #include "xls/ir/value_builder.h" +#include "xls/ir/value_flattening.h" +#include "xls/ir/value_utils.h" +#include "xls/passes/dfe_pass.h" +#include "xls/passes/inlining_pass.h" +#include "xls/passes/optimization_pass.h" +#include "xls/passes/pass_base.h" #include "xls/solvers/z3_ir_translator_matchers.h" #include "xls/solvers/z3_utils.h" #include "z3/src/api/z3.h" // IWYU pragma: keep @@ -2758,5 +2777,293 @@ TEST_F(Z3IrTranslatorTest, DumpWithNodeValues) { ContainsRegex("y: bits\\[32\\] id=[0-9]+ \\(0\\)"))); } +class FunctionInterpreter : public IrInterpreter { + public: + using IrInterpreter::IrInterpreter; + absl::Status HandleParam(Param* param) override { + XLS_RET_CHECK(HasResult(param)) << param; + return absl::OkStatus(); + } +}; + +TEST_F(Z3IrTranslatorTest, EquivArraySlice) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetArrayType(2, p->GetBitsType(4))); + BValue y = fb.Param("y", p->GetBitsType(4)); + BValue eq_0 = fb.Eq(y, fb.Literal(UBits(0, 4))); + BValue element_1 = fb.ArrayIndex(x, {fb.Literal(UBits(1, 4))}); + BValue end = fb.Array({element_1, element_1}, element_1.GetType()); + BValue transformed = fb.Select(eq_0, /*on_true=*/x, /*on_false=*/end); + BValue orig = fb.ArraySlice(x, y, 2); + fb.Eq(transformed, orig); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + // Exhaustively check all inputs for equivalence. + for (uint64_t x0 = 0; x0 < 16; ++x0) { + for (uint64_t x1 = 0; x1 < 16; ++x1) { + for (uint64_t y0 = 0; y0 < 16; ++y0) { + InterpreterEvents events; + XLS_ASSERT_OK_AND_ASSIGN(Value x_val, Value::UBitsArray({x0, x1}, 4)); + absl::flat_hash_map node_values{ + {x.node(), x_val}, {y.node(), Value(UBits(y0, 4))}}; + FunctionInterpreter interp(&node_values, &events); + XLS_ASSERT_OK(f->Accept(&interp)); + EXPECT_EQ(node_values.at(orig.node()), + node_values.at(transformed.node())) + << " @ [" << x0 << ", " << x1 << "], " << y; + } + } + } + XLS_ASSERT_OK_AND_ASSIGN( + ProverResult res, + TryProve(f, f->return_value(), Predicate::NotEqualToZero(), + absl::InfiniteDuration())); + EXPECT_THAT(res, IsProvenTrue()) + << f->DumpIr(solvers::z3::CounterExampleAnnotator( + std::get(res))); + RecordProperty("smtlib", solvers::z3::EmitFunctionAsSmtLib(f).value_or( + "Unable to emit smtlib")); +} +TEST_F(Z3IrTranslatorTest, ArraySliceBigger) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetArrayType(2, p->GetBitsType(4))); + BValue y = fb.Param("y", p->GetBitsType(4)); + BValue element_1 = fb.ArrayIndex(x, {fb.Literal(UBits(1, 4))}); + BValue orig = fb.ArraySlice(x, y, 200); + fb.Eq(element_1, fb.ArrayIndex(orig, {fb.Add(fb.SignExtend(y, 8), + fb.Literal(UBits(32, 8)))})); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + XLS_ASSERT_OK_AND_ASSIGN( + ProverResult res, + TryProve(f, f->return_value(), Predicate::NotEqualToZero(), + absl::InfiniteDuration())); + EXPECT_THAT(res, IsProvenTrue()) + << f->DumpIr(solvers::z3::CounterExampleAnnotator( + std::get(res))); + RecordProperty("smtlib", solvers::z3::EmitFunctionAsSmtLib(f).value_or( + "Unable to emit smtlib")); +} +TEST_F(Z3IrTranslatorTest, ArraySliceTranslate) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetArrayType(2, p->GetBitsType(4))); + BValue y = fb.Param("y", p->GetBitsType(4)); + BValue eq_0 = fb.Eq(y, fb.Literal(UBits(0, 4))); + BValue element_1 = fb.ArrayIndex(x, {fb.Literal(UBits(1, 4))}); + // BValue end = fb.Array({element_1, element_1}, element_1.GetType()); + BValue orig = fb.ArraySlice(x, y, 2); + BValue ret_val = fb.Or( + fb.And(fb.Eq(element_1, fb.ArrayIndex(orig, {fb.Literal(UBits(0, 1))})), + fb.Eq(element_1, fb.ArrayIndex(orig, {fb.Literal(UBits(1, 1))}))), + eq_0); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + // Exhaustively check all inputs for equivalence. + for (uint64_t x0 = 0; x0 < 16; ++x0) { + for (uint64_t x1 = 0; x1 < 16; ++x1) { + for (uint64_t y0 = 0; y0 < 16; ++y0) { + InterpreterEvents events; + XLS_ASSERT_OK_AND_ASSIGN(Value x_val, Value::UBitsArray({x0, x1}, 4)); + absl::flat_hash_map node_values{ + {x.node(), x_val}, {y.node(), Value(UBits(y0, 4))}}; + FunctionInterpreter interp(&node_values, &events); + XLS_ASSERT_OK(f->Accept(&interp)); + EXPECT_EQ(node_values.at(ret_val.node()), Value(UBits(1, 1))) + << " @ [" << x0 << ", " << x1 << "], " << y; + } + } + } + XLS_ASSERT_OK_AND_ASSIGN( + ProverResult res, + TryProve(f, f->return_value(), Predicate::NotEqualToZero(), + absl::InfiniteDuration())); + EXPECT_THAT(res, IsProvenTrue()) + << f->DumpIr(solvers::z3::CounterExampleAnnotator( + std::get(res))); + RecordProperty("smtlib", solvers::z3::EmitFunctionAsSmtLib(f).value_or( + "Unable to emit smtlib")); +} +TEST_F(Z3IrTranslatorTest, EquivArrayIndex) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetArrayType(2, p->GetBitsType(4))); + BValue y = fb.Param("y", p->GetBitsType(4)); + BValue eq_0 = fb.Eq(y, fb.Literal(UBits(0, 4))); + BValue element_0 = fb.ArrayIndex(x, {fb.Literal(UBits(0, 4))}); + BValue element_1 = fb.ArrayIndex(x, {fb.Literal(UBits(1, 4))}); + BValue transformed = + fb.Select(eq_0, /*on_true=*/element_0, /*on_false=*/element_1); + BValue orig = fb.ArrayIndex(x, {y}); + fb.Eq(transformed, orig); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + // Exhaustively check all inputs for equivalence. + for (uint64_t x0 = 0; x0 < 16; ++x0) { + for (uint64_t x1 = 0; x1 < 16; ++x1) { + for (uint64_t y0 = 0; y0 < 16; ++y0) { + InterpreterEvents events; + XLS_ASSERT_OK_AND_ASSIGN(Value x_val, Value::UBitsArray({x0, x1}, 4)); + absl::flat_hash_map node_values{ + {x.node(), x_val}, {y.node(), Value(UBits(y0, 4))}}; + FunctionInterpreter interp(&node_values, &events); + XLS_ASSERT_OK(f->Accept(&interp)); + EXPECT_EQ(node_values.at(orig.node()), + node_values.at(transformed.node())) + << " @ [" << x0 << ", " << x1 << "], " << y; + } + } + } + XLS_ASSERT_OK_AND_ASSIGN( + ProverResult res, + TryProve(f, f->return_value(), Predicate::NotEqualToZero(), + absl::InfiniteDuration())); + EXPECT_THAT(res, IsProvenTrue()) + << f->DumpIr(solvers::z3::CounterExampleAnnotator( + std::get(res))); + RecordProperty("smtlib", solvers::z3::EmitFunctionAsSmtLib(f).value_or( + "Unable to emit smtlib")); +} + +void Z3TranslationTest(std::shared_ptr package) { + // Get all the possible values this function can return. + // FunctionInterpreter interp(); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, package->GetTopAsFunction()); + XLS_ASSERT_OK_AND_ASSIGN( + ProverResult res, + TryProve(f, f->return_value(), Predicate::NotEqualToZero(), + absl::InfiniteDuration())); + EXPECT_THAT(res, IsProvenTrue()) + << f->DumpIr(solvers::z3::CounterExampleAnnotator( + std::get(res))); +} +namespace { +absl::Status AddResultChecks(std::shared_ptr package) { + XLS_ASSIGN_OR_RETURN(Function * func, package->GetTopAsFunction()); + LeafTypeTree> possible_return_values( + func->return_type(), absl::flat_hash_set{}); + int64_t num_inputs = 0; + for (Param* param : func->params()) { + num_inputs += param->GetType()->GetFlatBitCount(); + } + CHECK_LE(num_inputs, 20) << func->DumpIr(); + for (int64_t i = 0; i < (1 << num_inputs); ++i) { + auto bits = UBits(i, num_inputs); + auto it = bits.begin(); + absl::flat_hash_map node_values; + for (Param* param : func->params()) { + InlineBitmap value_bits(param->GetType()->GetFlatBitCount()); + for (int j = 0; j < param->GetType()->GetFlatBitCount(); ++j) { + value_bits.Set(j, *it); + ++it; + } + XLS_ASSIGN_OR_RETURN( + Value val, + UnflattenBitsToValue(Bits::FromBitmap(std::move(value_bits)), + param->GetType())); + node_values[param] = val; + } + FunctionInterpreter interp(&node_values, nullptr); + XLS_RETURN_IF_ERROR(func->Accept(&interp)); + XLS_ASSIGN_OR_RETURN( + LeafTypeTree ltt, + ValueToBitsLeafTypeTree(node_values.at(func->return_value()), + func->return_type())); + XLS_RETURN_IF_ERROR( + (leaf_type_tree::UpdateFrom, Bits>( + possible_return_values.AsMutableView(), ltt.AsView(), + [](Type*, absl::flat_hash_set& possible_values, + const Bits& element, absl::Span) { + possible_values.insert(element); + return absl::OkStatus(); + }))); + } + LeafTypeTree> possible_return_values_set; + FunctionBuilder fb(absl::StrCat("check_", func->name()), package.get()); + std::vector params; + for (Param* param : func->params()) { + params.push_back(fb.Param(param->name(), param->GetType())); + } + BValue res = fb.Invoke(params, func); + LeafTypeTree return_values = fb.MakeLeafTypeTree(res); + LeafTypeTree checks = + leaf_type_tree::Zip>( + return_values.AsView(), possible_return_values.AsView(), + [&](const BValue& element, + const absl::flat_hash_set& possible_values) -> BValue { + std::vector possible_values_vec(possible_values.begin(), + possible_values.end()); + absl::c_sort(possible_values_vec, [](const Bits& a, const Bits& b) { + return bits_ops::ULessThan(a, b); + }); + std::vector checks; + for (const Bits& possible_value : possible_values_vec) { + checks.push_back(fb.Eq(element, fb.Literal(possible_value))); + } + if (checks.size() == 1) { + return checks[0]; + } + CHECK(!checks.empty()); + return fb.Or(checks); + }); + fb.And(checks.elements()); + XLS_ASSIGN_OR_RETURN(Function * check_func, fb.Build()); + XLS_RETURN_IF_ERROR(package->SetTop(check_func)); + OptimizationCompoundPass pass("check_pass", "check_pass"); + pass.Add(); + pass.Add(); + PassResults pass_res; + OptimizationContext ctx; + XLS_RETURN_IF_ERROR(pass.Run(package.get(), {}, &pass_res, ctx).status()); + return absl::OkStatus(); +} + +auto WithResultCheck(fuzztest::Domain> domain) { + return fuzztest::Map( + [](std::shared_ptr pkg) -> std::shared_ptr { + CHECK_OK(AddResultChecks(pkg)); + return pkg; + }, + fuzztest::Filter( + [](std::shared_ptr pkg) { + return pkg->GetTopAsFunction() + .value() + ->return_type() + ->GetFlatBitCount() != 0; + }, + domain)); +} +} // namespace + +TEST(IrFuzzTest, CheckWithResultCheck) { + std::shared_ptr package = std::make_shared("test"); + FunctionBuilder fb("test", package.get()); + BValue x = fb.Concat({fb.Literal(UBits(0, 2)), + fb.Param("x", package->GetBitsType(3)), + fb.Literal(UBits(0, 2))}); + BValue y = fb.ZeroExtend(fb.Param("y", package->GetBitsType(3)), 7); + fb.Add(x, y); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + XLS_ASSERT_OK(package->SetTop(f)); + XLS_ASSERT_OK(AddResultChecks(package)); + RecordProperty("package", package->DumpIr()); + Z3TranslationTest(package); +} +FUZZ_TEST(IrFuzzTest, Z3TranslationTest) + .WithDomains(WithResultCheck( + PackageDomainBuilder() + .NoDefineFunction() + .NoInvoke() + .NoClz() + .NoCtz() + .WithParamBits(20) + .WithoutOperations( + {// We might try to look at the tuple elements which means that + // we can't actually just use the interpreter value since that + // only gives one of the possible pairs that could be used. + Op::kSMulp, Op::kUMulp, + // Not handled for translation + Op::kCover, Op::kAssert, Op::kTrace}) + .WithCombineListMethod(CombineListMethod::TUPLE_LIST_METHOD) + .Build())); + } // namespace } // namespace xls