diff --git a/xls/solvers/BUILD b/xls/solvers/BUILD index 1c31f68e68..155c6aefb9 100644 --- a/xls/solvers/BUILD +++ b/xls/solvers/BUILD @@ -158,10 +158,17 @@ cc_test( ":z3_ir_equivalence_testutils", "//xls/common:xls_gunit_main", "//xls/common/status:matchers", + "//xls/common/status:ret_check", + "//xls/interpreter:ir_interpreter", "//xls/ir", + "//xls/ir:bits", + "//xls/ir:events", "//xls/ir:function_builder", "//xls/ir:ir_test_base", "//xls/ir:value", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", "@googletest//:gtest", ], ) @@ -211,11 +218,18 @@ cc_test( ":z3_ir_translator_matchers", ":z3_utils", "//xls/common:xls_gunit_main", + "//xls/common/fuzzing:fuzztest", "//xls/common/status:matchers", + "//xls/common/status:ret_check", + "//xls/common/status:status_macros", + "//xls/data_structures:inline_bitmap", + "//xls/data_structures:leaf_type_tree", + "//xls/fuzzer/ir_fuzzer:ir_fuzz_domain", "//xls/interpreter:ir_interpreter", "//xls/ir", "//xls/ir:bits", "//xls/ir:bits_ops", + "//xls/ir:events", "//xls/ir:format_preference", "//xls/ir:function_builder", "//xls/ir:ir_parser", @@ -224,8 +238,18 @@ cc_test( "//xls/ir:source_location", "//xls/ir:value", "//xls/ir:value_builder", + "//xls/ir:value_flattening", + "//xls/ir:value_utils", + "//xls/passes:dfe_pass", + "//xls/passes:inlining_pass", + "//xls/passes:optimization_pass", + "//xls/passes:pass_base", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", diff --git a/xls/solvers/z3_ir_equivalence_testutils_test.cc b/xls/solvers/z3_ir_equivalence_testutils_test.cc index a9fd9ee5ca..4ad288eef8 100644 --- a/xls/solvers/z3_ir_equivalence_testutils_test.cc +++ b/xls/solvers/z3_ir_equivalence_testutils_test.cc @@ -14,8 +14,17 @@ #include "xls/solvers/z3_ir_equivalence_testutils.h" +#include + #include "gtest/gtest.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/types/span.h" #include "xls/common/status/matchers.h" +#include "xls/common/status/ret_check.h" +#include "xls/interpreter/ir_interpreter.h" +#include "xls/ir/bits.h" +#include "xls/ir/events.h" #include "xls/ir/function_builder.h" #include "xls/ir/ir_test_base.h" #include "xls/ir/nodes.h" @@ -42,5 +51,47 @@ TEST_F(Z3IrEquivalenceTestutilsTest, EquivWithAssert) { XLS_ASSERT_OK(f->RemoveNode(add.node())); } +class FunctionInterpreter : public IrInterpreter { + public: + using IrInterpreter::IrInterpreter; + absl::Status HandleParam(Param* param) override { + XLS_RET_CHECK(HasResult(param)) << param; + return absl::OkStatus(); + } +}; + +TEST_F(Z3IrEquivalenceTestutilsTest, EquivArraySlice) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetArrayType(2, p->GetBitsType(4))); + BValue y = fb.Param("y", p->GetBitsType(4)); + BValue eq_0 = fb.Eq(y, fb.Literal(UBits(0, 4))); + fb.ArrayIndex(x, {fb.Literal(UBits(0, 4))}); + BValue element_1 = fb.ArrayIndex(x, {fb.Literal(UBits(1, 4))}); + BValue end = fb.Array({element_1, element_1}, element_1.GetType()); + BValue transformed = fb.Select(eq_0, /*on_true=*/x, /*on_false=*/end); + BValue res = fb.ArraySlice(x, y, 2); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + ScopedVerifyEquivalence sve(f); + // Exhaustively check all inputs for equivalence. + for (uint64_t x0 = 0; x0 < 16; ++x0) { + for (uint64_t x1 = 0; x1 < 16; ++x1) { + for (uint64_t y0 = 0; y0 < 16; ++y0) { + InterpreterEvents events; + XLS_ASSERT_OK_AND_ASSIGN(Value x_val, Value::UBitsArray({x0, x1}, 4)); + absl::flat_hash_map node_values{ + {x.node(), x_val}, {y.node(), Value(UBits(y0, 4))}}; + FunctionInterpreter interp(&node_values, &events); + XLS_ASSERT_OK(f->Accept(&interp)); + EXPECT_EQ(node_values.at(res.node()), + node_values.at(transformed.node())) + << " @ [" << x0 << ", " << x1 << "], " << y; + } + } + } + XLS_ASSERT_OK( + res.node()->ReplaceImplicitUsesWith(transformed.node()).status()); +} + } // namespace } // namespace xls::solvers::z3 diff --git a/xls/solvers/z3_ir_translator.cc b/xls/solvers/z3_ir_translator.cc index 288ec012c4..dabe47bc85 100644 --- a/xls/solvers/z3_ir_translator.cc +++ b/xls/solvers/z3_ir_translator.cc @@ -926,13 +926,21 @@ absl::Status IrTranslator::HandleArraySlice(ArraySlice* array_slice) { Z3_ast start_ast = GetValue(array_slice->start()); ArrayType* input_type = array_slice->array()->GetType()->AsArrayOrDie(); ArrayType result_type(array_slice->width(), input_type->element_type()); - Z3_ast formatted_start_ast = - GetAsFormattedArrayIndex(ctx_, start_ast, input_type); - + int64_t min_offset_bits = Bits::MinBitCountUnsigned(array_slice->width()); + // Make sure we don't overflow. + int64_t offset_width = + 1 + std::max( + min_offset_bits, + Z3_get_bv_sort_size(ctx_, Z3_get_sort(ctx_, start_ast))); + Z3_sort index_sort = Z3_mk_bv_sort(ctx_, offset_width); + Z3_ast start_ext = Z3_mk_zero_ext( + ctx_, + offset_width - Z3_get_bv_sort_size(ctx_, Z3_get_sort(ctx_, start_ast)), + start_ast); std::vector elements; for (uint64_t i = 0; i < array_slice->width(); ++i) { - Z3_ast i_ast = GetAsFormattedArrayIndex(ctx_, i, input_type); - Z3_ast index_ast = Z3_mk_bvadd(ctx_, i_ast, formatted_start_ast); + Z3_ast i_ast = Z3_mk_int64(ctx_, i, index_sort); + Z3_ast index_ast = Z3_mk_bvadd(ctx_, start_ext, i_ast); elements.push_back(GetArrayElement(input_type, array_ast, index_ast)); } diff --git a/xls/solvers/z3_ir_translator_test.cc b/xls/solvers/z3_ir_translator_test.cc index 8b3b8e5202..7da319bc73 100644 --- a/xls/solvers/z3_ir_translator_test.cc +++ b/xls/solvers/z3_ir_translator_test.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -25,17 +26,29 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "xls/common/fuzzing/fuzztest.h" +#include "absl/algorithm/container.h" #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" +#include "absl/strings/str_cat.h" #include "absl/strings/substitute.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "xls/common/status/matchers.h" +#include "xls/common/status/ret_check.h" +#include "xls/common/status/status_macros.h" +#include "xls/data_structures/inline_bitmap.h" +#include "xls/data_structures/leaf_type_tree.h" +#include "xls/fuzzer/ir_fuzzer/ir_fuzz_domain.h" #include "xls/interpreter/ir_interpreter.h" #include "xls/ir/bits.h" #include "xls/ir/bits_ops.h" +#include "xls/ir/events.h" #include "xls/ir/format_preference.h" #include "xls/ir/function_builder.h" #include "xls/ir/ir_parser.h" @@ -47,6 +60,12 @@ #include "xls/ir/source_location.h" #include "xls/ir/value.h" #include "xls/ir/value_builder.h" +#include "xls/ir/value_flattening.h" +#include "xls/ir/value_utils.h" +#include "xls/passes/dfe_pass.h" +#include "xls/passes/inlining_pass.h" +#include "xls/passes/optimization_pass.h" +#include "xls/passes/pass_base.h" #include "xls/solvers/z3_ir_translator_matchers.h" #include "xls/solvers/z3_utils.h" #include "z3/src/api/z3.h" // IWYU pragma: keep @@ -2758,5 +2777,293 @@ TEST_F(Z3IrTranslatorTest, DumpWithNodeValues) { ContainsRegex("y: bits\\[32\\] id=[0-9]+ \\(0\\)"))); } +class FunctionInterpreter : public IrInterpreter { + public: + using IrInterpreter::IrInterpreter; + absl::Status HandleParam(Param* param) override { + XLS_RET_CHECK(HasResult(param)) << param; + return absl::OkStatus(); + } +}; + +TEST_F(Z3IrTranslatorTest, EquivArraySlice) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetArrayType(2, p->GetBitsType(4))); + BValue y = fb.Param("y", p->GetBitsType(4)); + BValue eq_0 = fb.Eq(y, fb.Literal(UBits(0, 4))); + BValue element_1 = fb.ArrayIndex(x, {fb.Literal(UBits(1, 4))}); + BValue end = fb.Array({element_1, element_1}, element_1.GetType()); + BValue transformed = fb.Select(eq_0, /*on_true=*/x, /*on_false=*/end); + BValue orig = fb.ArraySlice(x, y, 2); + fb.Eq(transformed, orig); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + // Exhaustively check all inputs for equivalence. + for (uint64_t x0 = 0; x0 < 16; ++x0) { + for (uint64_t x1 = 0; x1 < 16; ++x1) { + for (uint64_t y0 = 0; y0 < 16; ++y0) { + InterpreterEvents events; + XLS_ASSERT_OK_AND_ASSIGN(Value x_val, Value::UBitsArray({x0, x1}, 4)); + absl::flat_hash_map node_values{ + {x.node(), x_val}, {y.node(), Value(UBits(y0, 4))}}; + FunctionInterpreter interp(&node_values, &events); + XLS_ASSERT_OK(f->Accept(&interp)); + EXPECT_EQ(node_values.at(orig.node()), + node_values.at(transformed.node())) + << " @ [" << x0 << ", " << x1 << "], " << y; + } + } + } + XLS_ASSERT_OK_AND_ASSIGN( + ProverResult res, + TryProve(f, f->return_value(), Predicate::NotEqualToZero(), + absl::InfiniteDuration())); + EXPECT_THAT(res, IsProvenTrue()) + << f->DumpIr(solvers::z3::CounterExampleAnnotator( + std::get(res))); + RecordProperty("smtlib", solvers::z3::EmitFunctionAsSmtLib(f).value_or( + "Unable to emit smtlib")); +} +TEST_F(Z3IrTranslatorTest, ArraySliceBigger) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetArrayType(2, p->GetBitsType(4))); + BValue y = fb.Param("y", p->GetBitsType(4)); + BValue element_1 = fb.ArrayIndex(x, {fb.Literal(UBits(1, 4))}); + BValue orig = fb.ArraySlice(x, y, 200); + fb.Eq(element_1, fb.ArrayIndex(orig, {fb.Add(fb.SignExtend(y, 8), + fb.Literal(UBits(32, 8)))})); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + XLS_ASSERT_OK_AND_ASSIGN( + ProverResult res, + TryProve(f, f->return_value(), Predicate::NotEqualToZero(), + absl::InfiniteDuration())); + EXPECT_THAT(res, IsProvenTrue()) + << f->DumpIr(solvers::z3::CounterExampleAnnotator( + std::get(res))); + RecordProperty("smtlib", solvers::z3::EmitFunctionAsSmtLib(f).value_or( + "Unable to emit smtlib")); +} +TEST_F(Z3IrTranslatorTest, ArraySliceTranslate) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetArrayType(2, p->GetBitsType(4))); + BValue y = fb.Param("y", p->GetBitsType(4)); + BValue eq_0 = fb.Eq(y, fb.Literal(UBits(0, 4))); + BValue element_1 = fb.ArrayIndex(x, {fb.Literal(UBits(1, 4))}); + // BValue end = fb.Array({element_1, element_1}, element_1.GetType()); + BValue orig = fb.ArraySlice(x, y, 2); + BValue ret_val = fb.Or( + fb.And(fb.Eq(element_1, fb.ArrayIndex(orig, {fb.Literal(UBits(0, 1))})), + fb.Eq(element_1, fb.ArrayIndex(orig, {fb.Literal(UBits(1, 1))}))), + eq_0); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + // Exhaustively check all inputs for equivalence. + for (uint64_t x0 = 0; x0 < 16; ++x0) { + for (uint64_t x1 = 0; x1 < 16; ++x1) { + for (uint64_t y0 = 0; y0 < 16; ++y0) { + InterpreterEvents events; + XLS_ASSERT_OK_AND_ASSIGN(Value x_val, Value::UBitsArray({x0, x1}, 4)); + absl::flat_hash_map node_values{ + {x.node(), x_val}, {y.node(), Value(UBits(y0, 4))}}; + FunctionInterpreter interp(&node_values, &events); + XLS_ASSERT_OK(f->Accept(&interp)); + EXPECT_EQ(node_values.at(ret_val.node()), Value(UBits(1, 1))) + << " @ [" << x0 << ", " << x1 << "], " << y; + } + } + } + XLS_ASSERT_OK_AND_ASSIGN( + ProverResult res, + TryProve(f, f->return_value(), Predicate::NotEqualToZero(), + absl::InfiniteDuration())); + EXPECT_THAT(res, IsProvenTrue()) + << f->DumpIr(solvers::z3::CounterExampleAnnotator( + std::get(res))); + RecordProperty("smtlib", solvers::z3::EmitFunctionAsSmtLib(f).value_or( + "Unable to emit smtlib")); +} +TEST_F(Z3IrTranslatorTest, EquivArrayIndex) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetArrayType(2, p->GetBitsType(4))); + BValue y = fb.Param("y", p->GetBitsType(4)); + BValue eq_0 = fb.Eq(y, fb.Literal(UBits(0, 4))); + BValue element_0 = fb.ArrayIndex(x, {fb.Literal(UBits(0, 4))}); + BValue element_1 = fb.ArrayIndex(x, {fb.Literal(UBits(1, 4))}); + BValue transformed = + fb.Select(eq_0, /*on_true=*/element_0, /*on_false=*/element_1); + BValue orig = fb.ArrayIndex(x, {y}); + fb.Eq(transformed, orig); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + // Exhaustively check all inputs for equivalence. + for (uint64_t x0 = 0; x0 < 16; ++x0) { + for (uint64_t x1 = 0; x1 < 16; ++x1) { + for (uint64_t y0 = 0; y0 < 16; ++y0) { + InterpreterEvents events; + XLS_ASSERT_OK_AND_ASSIGN(Value x_val, Value::UBitsArray({x0, x1}, 4)); + absl::flat_hash_map node_values{ + {x.node(), x_val}, {y.node(), Value(UBits(y0, 4))}}; + FunctionInterpreter interp(&node_values, &events); + XLS_ASSERT_OK(f->Accept(&interp)); + EXPECT_EQ(node_values.at(orig.node()), + node_values.at(transformed.node())) + << " @ [" << x0 << ", " << x1 << "], " << y; + } + } + } + XLS_ASSERT_OK_AND_ASSIGN( + ProverResult res, + TryProve(f, f->return_value(), Predicate::NotEqualToZero(), + absl::InfiniteDuration())); + EXPECT_THAT(res, IsProvenTrue()) + << f->DumpIr(solvers::z3::CounterExampleAnnotator( + std::get(res))); + RecordProperty("smtlib", solvers::z3::EmitFunctionAsSmtLib(f).value_or( + "Unable to emit smtlib")); +} + +void Z3TranslationTest(std::shared_ptr package) { + // Get all the possible values this function can return. + // FunctionInterpreter interp(); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, package->GetTopAsFunction()); + XLS_ASSERT_OK_AND_ASSIGN( + ProverResult res, + TryProve(f, f->return_value(), Predicate::NotEqualToZero(), + absl::InfiniteDuration())); + EXPECT_THAT(res, IsProvenTrue()) + << f->DumpIr(solvers::z3::CounterExampleAnnotator( + std::get(res))); +} +namespace { +absl::Status AddResultChecks(std::shared_ptr package) { + XLS_ASSIGN_OR_RETURN(Function * func, package->GetTopAsFunction()); + LeafTypeTree> possible_return_values( + func->return_type(), absl::flat_hash_set{}); + int64_t num_inputs = 0; + for (Param* param : func->params()) { + num_inputs += param->GetType()->GetFlatBitCount(); + } + CHECK_LE(num_inputs, 20) << func->DumpIr(); + for (int64_t i = 0; i < (1 << num_inputs); ++i) { + auto bits = UBits(i, num_inputs); + auto it = bits.begin(); + absl::flat_hash_map node_values; + for (Param* param : func->params()) { + InlineBitmap value_bits(param->GetType()->GetFlatBitCount()); + for (int j = 0; j < param->GetType()->GetFlatBitCount(); ++j) { + value_bits.Set(j, *it); + ++it; + } + XLS_ASSIGN_OR_RETURN( + Value val, + UnflattenBitsToValue(Bits::FromBitmap(std::move(value_bits)), + param->GetType())); + node_values[param] = val; + } + FunctionInterpreter interp(&node_values, nullptr); + XLS_RETURN_IF_ERROR(func->Accept(&interp)); + XLS_ASSIGN_OR_RETURN( + LeafTypeTree ltt, + ValueToBitsLeafTypeTree(node_values.at(func->return_value()), + func->return_type())); + XLS_RETURN_IF_ERROR( + (leaf_type_tree::UpdateFrom, Bits>( + possible_return_values.AsMutableView(), ltt.AsView(), + [](Type*, absl::flat_hash_set& possible_values, + const Bits& element, absl::Span) { + possible_values.insert(element); + return absl::OkStatus(); + }))); + } + LeafTypeTree> possible_return_values_set; + FunctionBuilder fb(absl::StrCat("check_", func->name()), package.get()); + std::vector params; + for (Param* param : func->params()) { + params.push_back(fb.Param(param->name(), param->GetType())); + } + BValue res = fb.Invoke(params, func); + LeafTypeTree return_values = fb.MakeLeafTypeTree(res); + LeafTypeTree checks = + leaf_type_tree::Zip>( + return_values.AsView(), possible_return_values.AsView(), + [&](const BValue& element, + const absl::flat_hash_set& possible_values) -> BValue { + std::vector possible_values_vec(possible_values.begin(), + possible_values.end()); + absl::c_sort(possible_values_vec, [](const Bits& a, const Bits& b) { + return bits_ops::ULessThan(a, b); + }); + std::vector checks; + for (const Bits& possible_value : possible_values_vec) { + checks.push_back(fb.Eq(element, fb.Literal(possible_value))); + } + if (checks.size() == 1) { + return checks[0]; + } + CHECK(!checks.empty()); + return fb.Or(checks); + }); + fb.And(checks.elements()); + XLS_ASSIGN_OR_RETURN(Function * check_func, fb.Build()); + XLS_RETURN_IF_ERROR(package->SetTop(check_func)); + OptimizationCompoundPass pass("check_pass", "check_pass"); + pass.Add(); + pass.Add(); + PassResults pass_res; + OptimizationContext ctx; + XLS_RETURN_IF_ERROR(pass.Run(package.get(), {}, &pass_res, ctx).status()); + return absl::OkStatus(); +} + +auto WithResultCheck(fuzztest::Domain> domain) { + return fuzztest::Map( + [](std::shared_ptr pkg) -> std::shared_ptr { + CHECK_OK(AddResultChecks(pkg)); + return pkg; + }, + fuzztest::Filter( + [](std::shared_ptr pkg) { + return pkg->GetTopAsFunction() + .value() + ->return_type() + ->GetFlatBitCount() != 0; + }, + domain)); +} +} // namespace + +TEST(IrFuzzTest, CheckWithResultCheck) { + std::shared_ptr package = std::make_shared("test"); + FunctionBuilder fb("test", package.get()); + BValue x = fb.Concat({fb.Literal(UBits(0, 2)), + fb.Param("x", package->GetBitsType(3)), + fb.Literal(UBits(0, 2))}); + BValue y = fb.ZeroExtend(fb.Param("y", package->GetBitsType(3)), 7); + fb.Add(x, y); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + XLS_ASSERT_OK(package->SetTop(f)); + XLS_ASSERT_OK(AddResultChecks(package)); + RecordProperty("package", package->DumpIr()); + Z3TranslationTest(package); +} +FUZZ_TEST(IrFuzzTest, Z3TranslationTest) + .WithDomains(WithResultCheck( + PackageDomainBuilder() + .NoDefineFunction() + .NoInvoke() + .NoClz() + .NoCtz() + .WithParamBits(20) + .WithoutOperations( + {// We might try to look at the tuple elements which means that + // we can't actually just use the interpreter value since that + // only gives one of the possible pairs that could be used. + Op::kSMulp, Op::kUMulp, + // Not handled for translation + Op::kCover, Op::kAssert, Op::kTrace}) + .WithCombineListMethod(CombineListMethod::TUPLE_LIST_METHOD) + .Build())); + } // namespace } // namespace xls