From 5d7703cee3c5eb24a69568813b43eda9dff54111 Mon Sep 17 00:00:00 2001 From: Eric Astor Date: Wed, 8 Apr 2026 14:52:11 -0700 Subject: [PATCH] [opt] Add simplification for division by a bounded set of constants. When a division's divisor is known to take values only from a small set of constants (or even a large set, if almost all of them are powers of two), we replace the expensive variable-divisor division with a select tree over constant divisions. This avoids expensive hardware dividers by allowing subsequent passes to reduce the constant divisions into multiplications or shifts. This provides a significant area and/or latency win for cases where the set of possible divisors is small - and gives us more opportunity to let users express division naturally without getting a full hardware divider. PiperOrigin-RevId: 896714610 --- .../division_simplification_design.md | 147 ++++++ mkdocs.yml | 3 + xls/passes/BUILD | 53 ++ xls/passes/optimization_pass_pipeline.txtpb | 2 + xls/passes/value_set_simplification_pass.cc | 460 ++++++++++++++++++ xls/passes/value_set_simplification_pass.h | 44 ++ .../value_set_simplification_pass_test.cc | 347 +++++++++++++ 7 files changed, 1056 insertions(+) create mode 100644 docs_src/design_docs/optimizations/value_set_simp/division_simplification_design.md create mode 100644 xls/passes/value_set_simplification_pass.cc create mode 100644 xls/passes/value_set_simplification_pass.h create mode 100644 xls/passes/value_set_simplification_pass_test.cc diff --git a/docs_src/design_docs/optimizations/value_set_simp/division_simplification_design.md b/docs_src/design_docs/optimizations/value_set_simp/division_simplification_design.md new file mode 100644 index 0000000000..b331894df8 --- /dev/null +++ b/docs_src/design_docs/optimizations/value_set_simp/division_simplification_design.md @@ -0,0 +1,147 @@ +# Division Simplification by Selector and QueryEngine + +Users often use division where the divisor takes a small set of values (e.g., +powers 
of two or a selection of constants). This document outlines the design +for simplifying such divisions to avoid expensive hardware dividers. + +## Design + +The optimization is implemented in `value_set_simplification_pass.cc`. It uses +the `PartialInfoQueryEngine` (retrieved via +`context.SharedQueryEngine(f)`) to evaluate the set of +possible values for the divisor `Y` in a division `div(x, Y)`. + +`PartialInfoQueryEngine` is preferred over standard ternary analysis because it +provides **interval/range information**. Range information is critical for +division; for example, if we know `Y` is in `[2, 4]`, we know it takes at most 3 +values, which ternary might miss if multiple bits are toggling. + +### Rule 1: All Possible Values are Powers of Two (or Zero) + +This rule applies when all possible values of `Y` are powers of two, or zero. +Let `K` be the number of distinct values `Y` can assume. + +XLS semantics define division by 0 as: If the divisor is zero, unsigned division +produces a maximal positive value. For signed division, if the divisor is zero +the result is the maximal positive value if the dividend is non-negative or the +maximal negative value if the dividend is negative. + +As such, zero is like a power of two in that dividing by it is close to free. + +Since we always prioritize area: + +- **Constant Shifts Tree**: Used if `K <= log2(M) + 1` (where `M` is max shift + amount). +- **Variable Shift Fallback**: Used if `K > log2(M) + 1`. + - **Zero Guard**: Rewrite to `Y == 0 ? (appropriate value) : x >> + Encode(Y)`. + +!!! NOTE + **Property Checks for Powers of Two**: To verify Rule 1, we scan the + `IntervalSet` to count powers of two and identify if they cover all values. + +!!! NOTE + **Signed Division**: This is slightly more complicated for `SDiv`, as the + dividend and divisor can each be negative. If the divisor is negative, we divide + by its absolute value and negate the quotient. 
If the dividend is negative, we + need to handle the fact that arithmetic shift-right rounds towards negative + infinity rather than towards 0. To fix this, we have to bias the dividend by + adding `Y - 1` before shifting - but **only** if the dividend is negative. + +### Rule 2: Some Possible Values are NOT Powers of Two + +Let `L` be the count of non-power-of-two constants. + +**How Sinking Works**: We rewrite the single division `div(x, Y)` into a +`priority_sel` over the possible constant values of `Y`. The branches of the +select become `div(x, C_1)`, `div(x, C_2)`, etc. This replaces a +variable-divisor division with multiple constant-divisor divisions. We rely on +the existing `arith_simplification_pass.cc` (which runs in the same pipeline) to +recognize `div(x, Constant)` and replace it with a multiply-and-shift using the +reciprocal. We do **not** implement the reciprocal multiplication logic here! + +- **If an Area Model is available**: Query the area of a single multiplication + of size `N` and the muxes. If it is cheaper than a divider, sink it. +- **If no Area Model is present**: Fall back to a safe universal limit of `L <= + 2` (replaces a divider with at most two multipliers, which is area-neutral + but a huge latency win). + +**Caveats and Edge Cases**: + +- **Skip Single Literals**: If `Y` is a single known literal constant, abort + early. Let `arith_simplification_pass.cc` handle it natively. +- **Division By Zero**: If the set of constants contains 0, do not emit + `div(x, 0)`. Instead, emit the XLS standard division-by-zero value directly + for that branch (e.g., all bits set to 1 for `UDiv`). + +### The Hybrid Case (Powers of Two + General Constants) + +When we have a mix of powers of two (`K_p2` of them) and general +non-power-of-two constants (`L` of them), we have two choices for the powers of +two: + +- **Option A (Separate)**: Keep powers of two as individual constant shifts. + Total cases: `K_p2 + L`. 
+- **Option B (Grouped)**: Group all powers of two into a single variable shift + case. Total cases: `L + 1`. + +**Decision Rule (Consistent with Rule 1)**: + +- **If an Area Model is available**: Directly compare options A and B using + the area estimator. +- **If no Area Model is present**: Fall back to a threshold: choose Option B + if `K_p2 > log2(M) + C` (where `C = 1` for `UDiv`, `C = 2` for `SDiv`). + +**Final Select Cardinality (`C_eff`)**: + +- For Option A: `C_eff = K_p2 + L` +- For Option B: `C_eff = L + 1` + +**Profitability Sinking Rules**: + +- **If an Area Model is available**: Query the area of the chosen approach vs + the divider. +- **If no Area Model is present**: Limit to `L <= 2` non-power-of-two + constants for Option A, or `L <= 1` for Option B. + +## Implementation Phases + +To ensure a smooth and incremental rollout, we will split the implementation +into three phases: + +### Phase 1: Rule 1 (Powers of Two Check) + +Implement Rule 1 using `PartialInfoQueryEngine`. Use `AtMostOneBitTrue` to +detect powers of two and `Op::kEncode` to calculate shift amounts. + +### Phase 2: Rule 2 (Implicit Select Sinking) + +Implement Rule 2 using `PartialInfoQueryEngine`. + +- Extract sets of constants from small intervals in `IntervalSet`. +- Synthesize a select tree. +- Fall back to `L <= 2` if no Area Model is available. + +### Phase 3: Hybrid Cases + +Merge the powers-of-two support and general constant support into the final +decision rule (Option A vs Option B). + +-------------------------------------------------------------------------------- + +## Alternatives Considered + +### 1. Simple Pattern Match in `arith_simplification_pass.cc` + +Brittle check for `div(x, sel(c, [constants]))`. + +- **Reason for Rejection**: Misses hidden selectors and non-immediate + constants. + +### 2. Generic Lifting in `select_lifting_pass.cc` + +Enable `UDiv`/`SDiv` for select lifting. 
+ +- **Reason for Rejection**: Select Lifting pulls operations *out* of selects + (e.g., `sel(c, [x/1, x/2]) -> x / sel(c)`). This is the opposite of what we + want! We want **Select Sinking** (pushing the division into the select). diff --git a/mkdocs.yml b/mkdocs.yml index a3139689b2..0d4d993da4 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -79,6 +79,9 @@ nav: - Proc-scoped channels: 'design_docs/proc_scoped_channels.md' - Synchronous Procs: 'design_docs/synchronous_procs.md' - DSLX Pattern Exhaustiveness: 'design_docs/dslx_pattern_exhaustiveness.md' + - Optimizations: + - Value Set Simplification: + - Division: 'design_docs/optimizations/value_set_simp/division_simplification_design.md' - Releasing: 'releasing.md' - NoC: - Overview: 'noc/xls_noc_readme.md' diff --git a/xls/passes/BUILD b/xls/passes/BUILD index fd10c2bf05..476d3ac01a 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -869,6 +869,7 @@ xls_pass( "//xls/common/status:ret_check", "//xls/common/status:status_macros", "//xls/data_structures:leaf_type_tree", + "//xls/estimators/area_model:area_estimator", "//xls/interpreter:ir_interpreter", "//xls/ir", "//xls/ir:bits", @@ -3037,6 +3038,7 @@ cc_test( "//xls/common/fuzzing:fuzztest", "//xls/common/status:matchers", "//xls/common/status:status_macros", + "//xls/estimators/area_model:area_estimator", "//xls/fuzzer/ir_fuzzer:ir_fuzz_domain", "//xls/fuzzer/ir_fuzzer:ir_fuzz_test_library", "//xls/ir", @@ -4842,3 +4844,54 @@ cc_test( "@googletest//:gtest", ], ) + +xls_pass( + name = "value_set_simplification_pass", + srcs = ["value_set_simplification_pass.cc"], + hdrs = ["value_set_simplification_pass.h"], + pass_class = "ValueSetSimplificationPass", + deps = [ + ":optimization_pass", + ":partial_info_query_engine", + ":pass_base", + ":query_engine", + ":stateless_query_engine", + ":union_query_engine", + "//xls/common:math_util", + "//xls/common/status:status_macros", + "//xls/estimators/area_model:area_estimator", + "//xls/ir", + "//xls/ir:bits", + 
"//xls/ir:interval_set", + "//xls/ir:op", + "//xls/ir:value", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_test( + name = "value_set_simplification_pass_test", + srcs = ["value_set_simplification_pass_test.cc"], + deps = [ + ":dce_pass", + ":optimization_pass", + ":pass_base", + ":value_set_simplification_pass", + "//xls/common:xls_gunit_main", + "//xls/common/status:matchers", + "//xls/common/status:status_macros", + "//xls/estimators/area_model:area_estimator", + "//xls/ir", + "//xls/ir:bits", + "//xls/ir:function_builder", + "//xls/ir:ir_matcher", + "//xls/ir:ir_test_base", + "//xls/ir:op", + "//xls/solvers:z3_ir_equivalence_testutils", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@googletest//:gtest", + ], +) diff --git a/xls/passes/optimization_pass_pipeline.txtpb b/xls/passes/optimization_pass_pipeline.txtpb index 287e91574e..199e62aa0e 100644 --- a/xls/passes/optimization_pass_pipeline.txtpb +++ b/xls/passes/optimization_pass_pipeline.txtpb @@ -66,6 +66,8 @@ compound_passes: [ "dce", "dataflow", "dce", + "value_set_simp", + "dce", "strength_red", "dce", "array_simp", diff --git a/xls/passes/value_set_simplification_pass.cc b/xls/passes/value_set_simplification_pass.cc new file mode 100644 index 0000000000..a8b8e5a871 --- /dev/null +++ b/xls/passes/value_set_simplification_pass.cc @@ -0,0 +1,460 @@ +// Copyright 2026 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xls/passes/value_set_simplification_pass.h" + +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "xls/common/math_util.h" +#include "xls/common/status/status_macros.h" +#include "xls/estimators/area_model/area_estimator.h" +#include "xls/ir/bits.h" +#include "xls/ir/interval_set.h" +#include "xls/ir/node.h" +#include "xls/ir/nodes.h" +#include "xls/ir/op.h" +#include "xls/ir/value.h" +#include "xls/passes/optimization_pass.h" +#include "xls/passes/partial_info_query_engine.h" +#include "xls/passes/pass_base.h" +#include "xls/passes/query_engine.h" +#include "xls/passes/stateless_query_engine.h" +#include "xls/passes/union_query_engine.h" + +namespace xls { +namespace { + +class TemporaryNodeScope { + public: + explicit TemporaryNodeScope(FunctionBase* f) : f_(f) {} + ~TemporaryNodeScope() { + for (auto it = nodes_.rbegin(); it != nodes_.rend(); ++it) { + CHECK_OK(f_->RemoveNode(*it)); + } + } + + template + absl::StatusOr AddNode(Args&&... args) { + XLS_ASSIGN_OR_RETURN(Node * n, + f_->MakeNode(std::forward(args)...)); + nodes_.push_back(n); + return n; + } + + private: + FunctionBase* f_; + std::vector nodes_; +}; + +template +absl::StatusOr GetAreaForNode(const AreaEstimator* area_estimator, + FunctionBase* f, Args&&... 
args) { + TemporaryNodeScope temp_nodes(f); + XLS_ASSIGN_OR_RETURN(Node * temp, + temp_nodes.AddNode(std::forward(args)...)); + absl::StatusOr area = + area_estimator->GetOperationAreaInSquareMicrons(temp); + return area; +} + +absl::StatusOr CreateDivZeroResult(Node* node, Node* dividend) { + int64_t bit_width = dividend->BitCountOrDie(); + if (node->op() == Op::kUDiv) { + return node->function_base()->MakeNode( + node->loc(), Value(Bits::AllOnes(bit_width))); + } + XLS_ASSIGN_OR_RETURN(Node * msb, + node->function_base()->MakeNode( + node->loc(), dividend, bit_width - 1, 1)); + XLS_ASSIGN_OR_RETURN(Node * max_signed, + node->function_base()->MakeNode( + node->loc(), Value(Bits::MaxSigned(bit_width)))); + XLS_ASSIGN_OR_RETURN(Node * min_signed, + node->function_base()->MakeNode( + node->loc(), Value(Bits::MinSigned(bit_width)))); + return node->function_base()->MakeNode( + node->loc(), msb, std::vector{zero, sub}, + /*default_value=*/std::nullopt)); + XLS_ASSIGN_OR_RETURN( + Node * add, + temp_nodes.AddNode(node->loc(), dividend, bias_sel, Op::kAdd)); + + XLS_ASSIGN_OR_RETURN( + double sub_area, + area_estimator->GetOperationAreaInSquareMicrons(sub)); + XLS_ASSIGN_OR_RETURN( + double sel_area, + area_estimator->GetOperationAreaInSquareMicrons(bias_sel)); + XLS_ASSIGN_OR_RETURN( + double add_area, + area_estimator->GetOperationAreaInSquareMicrons(add)); + extra_sdiv_area = sub_area + sel_area + add_area; + } + + prefer_shift_by_encode = all_constants_select_area > + non_pow2_constants_select_area + encode_area + + shift_area + extra_sdiv_area; + + if (prefer_shift_by_encode) { + profitable = spec.non_pow2_constants_count * mul_area + + non_pow2_constants_select_area + encode_area + + shift_area + extra_sdiv_area < + div_area; + } else { + profitable = + spec.non_pow2_constants_count * mul_area + all_constants_select_area < + div_area; + } + } else { + prefer_shift_by_encode = + spec.pow2_constants_count > + FloorOfLog2(spec.max_pow2_shift) + (node->op() == 
Op::kSDiv ? 2 : 1); + profitable = prefer_shift_by_encode ? (spec.non_pow2_constants_count <= 1) + : (spec.non_pow2_constants_count <= 2); + } + + if (profitable) { + DivisionSimplificationSpec updated_spec = spec; + updated_spec.prefer_shift_by_encode = prefer_shift_by_encode; + return updated_spec; + } + return std::nullopt; +} + +absl::StatusOr ApplyMultipleConstantDivisionTransformation( + Node* node, const DivisionSimplificationSpec& spec) { + FunctionBase* f = node->function_base(); + Node* dividend = node->operand(0); + Node* divisor = node->operand(1); + int64_t bit_width = node->BitCountOrDie(); + + std::vector predicates; + std::vector cases; + Node* default_case = nullptr; + + std::vector values; + for (const Bits& v : spec.intervals.Values()) { + values.push_back(v); + } + + if (spec.prefer_shift_by_encode) { + for (const Bits& value : values) { + if (value.IsZero() || value.IsPowerOfTwo()) { + continue; + } + XLS_ASSIGN_OR_RETURN(Node * value_lit, + f->MakeNode(node->loc(), Value(value))); + XLS_ASSIGN_OR_RETURN( + Node * eq, + f->MakeNode(node->loc(), divisor, value_lit, Op::kEq)); + predicates.push_back(eq); + XLS_ASSIGN_OR_RETURN( + Node * div_case, + f->MakeNode(node->loc(), dividend, value_lit, node->op())); + cases.push_back(div_case); + } + + XLS_ASSIGN_OR_RETURN(Node * shift_amt, + f->MakeNode(node->loc(), divisor)); + if (node->op() == Op::kUDiv) { + XLS_ASSIGN_OR_RETURN( + Node * shifted, + f->MakeNode(node->loc(), dividend, shift_amt, Op::kShrl)); + default_case = shifted; + } else { + // SDiv requires a bias for negative numbers to truncate towards zero + // rather than negative infinity (which is what Shra does). + // Bias formula: (dividend < 0 ? divisor - 1 : 0). + // + // Note: we can also compute divisor - 1 as ~(all_ones << shift_amt). + // We choose subtraction over shifting because in hardware, an adder (Sub) + // is usually smaller than a barrel shifter (Shll) for large widths. 
+ XLS_ASSIGN_OR_RETURN( + Node * msb, + f->MakeNode(node->loc(), dividend, bit_width - 1, 1)); + XLS_ASSIGN_OR_RETURN( + Node * one, + f->MakeNode(node->loc(), Value(UBits(1, bit_width)))); + XLS_ASSIGN_OR_RETURN( + Node * divisor_minus_one, + f->MakeNode(node->loc(), divisor, one, Op::kSub)); + XLS_ASSIGN_OR_RETURN( + Node * zero, + f->MakeNode(node->loc(), Value(UBits(0, bit_width)))); + XLS_ASSIGN_OR_RETURN( + Node * bias, + f->MakeNode(node->loc(), is_zero, + std::vector{default_case, div_zero_result}, + std::nullopt)); + } + } else { + for (const Bits& value : values) { + if (value.IsZero()) { + continue; + } + XLS_ASSIGN_OR_RETURN(Node * value_lit, + f->MakeNode(node->loc(), Value(value))); + XLS_ASSIGN_OR_RETURN( + Node * eq, + f->MakeNode(node->loc(), divisor, value_lit, Op::kEq)); + predicates.push_back(eq); + XLS_ASSIGN_OR_RETURN( + Node * div_case, + f->MakeNode(node->loc(), dividend, value_lit, node->op())); + cases.push_back(div_case); + } + + if (spec.intervals.CoversZero()) { + XLS_ASSIGN_OR_RETURN(default_case, CreateDivZeroResult(node, dividend)); + } else { + XLS_ASSIGN_OR_RETURN( + default_case, + f->MakeNode(node->loc(), Value(UBits(0, bit_width)))); + } + } + + if (predicates.empty()) { + return default_case; + } + + std::vector reversed_predicates(predicates.rbegin(), + predicates.rend()); + XLS_ASSIGN_OR_RETURN(Node * selector, + f->MakeNode(node->loc(), reversed_predicates)); + + return f->MakeNode(node->loc(), selector, cases, + default_case); +} + +absl::StatusOr TrySimplifyDivisionWithMultipleConstants( + Node* node, const QueryEngine& query_engine, + const AreaEstimator* area_estimator) { + XLS_ASSIGN_OR_RETURN( + std::optional spec, + CheckMultipleConstantDivisionApplicability(node, query_engine)); + if (!spec.has_value()) { + return false; + } + + XLS_ASSIGN_OR_RETURN( + std::optional updated_spec, + IsMultipleConstantDivisionProfitable(node, *spec, area_estimator)); + if (!updated_spec.has_value()) { + return false; + } + + 
XLS_ASSIGN_OR_RETURN( + Node * result, + ApplyMultipleConstantDivisionTransformation(node, *updated_spec)); + XLS_RETURN_IF_ERROR(node->ReplaceUsesWith(result)); + return true; +} + +} // namespace + +absl::StatusOr ValueSetSimplificationPass::RunOnFunctionBaseInternal( + FunctionBase* f, const OptimizationPassOptions& options, + PassResults* results, OptimizationContext& context) const { + auto query_engine = UnionQueryEngine::Of( + StatelessQueryEngine(), + GetSharedQueryEngine(context, f)); + + XLS_RETURN_IF_ERROR(query_engine.Populate(f).status()); + + bool modified = false; + for (Node* node : context.TopoSort(f)) { + if (node->op() == Op::kUDiv || node->op() == Op::kSDiv) { + XLS_ASSIGN_OR_RETURN(bool node_modified, + TrySimplifyDivisionWithMultipleConstants( + node, query_engine, options.area_estimator)); + modified |= node_modified; + } + } + return modified; +} + +} // namespace xls diff --git a/xls/passes/value_set_simplification_pass.h b/xls/passes/value_set_simplification_pass.h new file mode 100644 index 0000000000..cecc6dd513 --- /dev/null +++ b/xls/passes/value_set_simplification_pass.h @@ -0,0 +1,44 @@ +// Copyright 2026 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef XLS_PASSES_VALUE_SET_SIMPLIFICATION_PASS_H_ +#define XLS_PASSES_VALUE_SET_SIMPLIFICATION_PASS_H_ + +#include + +#include "absl/status/statusor.h" +#include "xls/ir/function_base.h" +#include "xls/passes/optimization_pass.h" +#include "xls/passes/pass_base.h" + +namespace xls { + +// Simplifies operations where an operand can only take a small set of values +// (determined by range/interval analysis) into a select of cheaper operations. +class ValueSetSimplificationPass : public OptimizationFunctionBasePass { + public: + static constexpr std::string_view kName = "value_set_simp"; + explicit ValueSetSimplificationPass() + : OptimizationFunctionBasePass(kName, "Value Set Simplification") {} + ~ValueSetSimplificationPass() override = default; + + protected: + absl::StatusOr RunOnFunctionBaseInternal( + FunctionBase* f, const OptimizationPassOptions& options, + PassResults* results, OptimizationContext& context) const override; +}; + +} // namespace xls + +#endif // XLS_PASSES_VALUE_SET_SIMPLIFICATION_PASS_H_ diff --git a/xls/passes/value_set_simplification_pass_test.cc b/xls/passes/value_set_simplification_pass_test.cc new file mode 100644 index 0000000000..1e5bb0a7f4 --- /dev/null +++ b/xls/passes/value_set_simplification_pass_test.cc @@ -0,0 +1,347 @@ +// Copyright 2026 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "xls/passes/value_set_simplification_pass.h" + +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "xls/common/status/matchers.h" +#include "xls/common/status/status_macros.h" +#include "xls/estimators/area_model/area_estimator.h" +#include "xls/ir/bits.h" +#include "xls/ir/function.h" +#include "xls/ir/function_builder.h" +#include "xls/ir/ir_matcher.h" +#include "xls/ir/ir_test_base.h" +#include "xls/ir/lsb_or_msb.h" +#include "xls/ir/nodes.h" +#include "xls/ir/op.h" +#include "xls/ir/package.h" +#include "xls/passes/dce_pass.h" +#include "xls/passes/optimization_pass.h" +#include "xls/passes/pass_base.h" +#include "xls/solvers/z3_ir_equivalence_testutils.h" + +namespace m = ::xls::op_matchers; + +namespace xls { +namespace { + +using ::absl_testing::IsOkAndHolds; + +using ::xls::solvers::z3::ScopedVerifyEquivalence; + +class FakeAreaEstimator : public AreaEstimator { + public: + FakeAreaEstimator() : AreaEstimator("fake") {} + absl::StatusOr GetOperationAreaInSquareMicrons( + Node* node) const override { + if (node->op() == Op::kUMul || node->op() == Op::kSMul) { + return 10.0; + } + if (node->op() == Op::kUDiv || node->op() == Op::kSDiv) { + return 100.0; + } + return 1.0; + } + absl::StatusOr GetOneBitRegisterAreaInSquareMicrons() const override { + return 1.0; + } +}; + +class ValueSetSimplificationPassTest : public IrTestBase { + protected: + ValueSetSimplificationPassTest() = default; + + absl::StatusOr Run(Function* f) { + PassResults results; + OptimizationContext context; + XLS_ASSIGN_OR_RETURN(bool changed, + ValueSetSimplificationPass().RunOnFunctionBase( + f, OptimizationPassOptions(), &results, context)); + XLS_RETURN_IF_ERROR( + DeadCodeEliminationPass() + .RunOnFunctionBase(f, OptimizationPassOptions(), &results, context) + .status()); + return changed; + } +}; + +TEST_F(ValueSetSimplificationPassTest, UDivByPowerOfTwoVarShift) { + auto p 
= CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue dividend = fb.Param("dividend", p->GetBitsType(4)); + BValue s = fb.Param("s", p->GetBitsType(3)); + BValue divisor = fb.OneHot(s, LsbOrMsb::kLsb); + fb.UDiv(dividend, divisor); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + ScopedRecordIr sri(p.get()); + ScopedVerifyEquivalence sve(f); + ASSERT_THAT(Run(f), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), + m::Shrl(m::Param("dividend"), + m::Encode(m::OneHot(m::Param("s"), LsbOrMsb::kLsb)))); +} + +TEST_F(ValueSetSimplificationPassTest, UDivByPowerOfTwoOrZeroVarShift) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue dividend = fb.Param("dividend", p->GetBitsType(4)); + BValue s = fb.Param("s", p->GetBitsType(3)); + BValue divisor = fb.Decode(s, /*width=*/4); + fb.UDiv(dividend, divisor); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + ScopedRecordIr sri(p.get()); + ScopedVerifyEquivalence sve(f); + ASSERT_THAT(Run(f), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), + m::Select(m::Eq(m::Decode(), m::Literal(0)), + {m::Shrl(m::Param("dividend"), m::Encode(m::Decode())), + m::Literal(Bits::AllOnes(4))})); +} + +TEST_F(ValueSetSimplificationPassTest, UDivByPowerOfTwoOrZeroConstShift) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue dividend = fb.Param("dividend", p->GetBitsType(4)); + BValue s = fb.Param("s", p->GetBitsType(2)); + BValue divisor = + fb.Concat({fb.Decode(s, /*width=*/2), fb.Literal(UBits(0, 2))}); + fb.UDiv(dividend, divisor); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + ScopedRecordIr sri(p.get()); + ScopedVerifyEquivalence sve(f); + ASSERT_THAT(Run(f), IsOkAndHolds(true)); + EXPECT_THAT( + f->return_value(), + m::PrioritySelect(m::Concat(m::Eq(m::Concat(m::Decode(m::Param("s")), + m::Literal(UBits(0, 2))), + m::Literal(8)), + m::Eq(m::Concat(m::Decode(m::Param("s")), + m::Literal(UBits(0, 2))), + m::Literal(4))), + 
{m::UDiv(m::Param("dividend"), m::Literal(4)), + m::UDiv(m::Param("dividend"), m::Literal(8))}, + /*default_value=*/m::Literal(Bits::AllOnes(4)))); +} + +TEST_F(ValueSetSimplificationPassTest, UDivByFewConstantsFallback) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue dividend = fb.Param("dividend", p->GetBitsType(4)); + BValue s = fb.Param("s", p->GetBitsType(2)); + BValue divisor = + fb.Select(s, {fb.Literal(UBits(3, 4)), fb.Literal(UBits(5, 4)), + fb.Literal(UBits(3, 4)), fb.Literal(UBits(5, 4))}); + fb.UDiv(dividend, divisor); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + ScopedRecordIr sri(p.get()); + ScopedVerifyEquivalence sve(f); + ASSERT_THAT(Run(f), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), + m::PrioritySelect(m::Concat(m::Eq(m::Select(), m::Literal(5)), + m::Eq(m::Select(), m::Literal(3))), + {m::UDiv(m::Param("dividend"), m::Literal(3)), + m::UDiv(m::Param("dividend"), m::Literal(5))}, + /*default_value=*/m::Literal(0))); +} + +TEST_F(ValueSetSimplificationPassTest, UDivByMoreConstantsWithAreaEstimator) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue dividend = fb.Param("dividend", p->GetBitsType(4)); + BValue s = fb.Param("s", p->GetBitsType(2)); + BValue divisor = + fb.Select(s, {fb.Literal(UBits(3, 4)), fb.Literal(UBits(5, 4)), + fb.Literal(UBits(7, 4)), fb.Literal(UBits(3, 4))}); + fb.UDiv(dividend, divisor); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + + FakeAreaEstimator fake_ae; + OptimizationPassOptions options; + options.area_estimator = &fake_ae; + PassResults results; + OptimizationContext context; + + ScopedRecordIr sri(p.get()); + ScopedVerifyEquivalence sve(f); + ASSERT_THAT(ValueSetSimplificationPass().RunOnFunctionBase(f, options, + &results, context), + IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), + m::PrioritySelect(m::Concat(m::Eq(m::Select(), m::Literal(7)), + m::Eq(m::Select(), m::Literal(5)), + m::Eq(m::Select(), 
m::Literal(3))), + {m::UDiv(m::Param("dividend"), m::Literal(3)), + m::UDiv(m::Param("dividend"), m::Literal(5)), + m::UDiv(m::Param("dividend"), m::Literal(7))}, + /*default_value=*/m::Literal(0))); +} + +TEST_F(ValueSetSimplificationPassTest, UDivByMoreConstantsHybridCase) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue dividend = fb.Param("dividend", p->GetBitsType(4)); + BValue s = fb.Param("s", p->GetBitsType(3)); + BValue divisor = + fb.Select(s, {fb.Literal(UBits(1, 4)), fb.Literal(UBits(2, 4)), + fb.Literal(UBits(4, 4)), fb.Literal(UBits(8, 4)), + fb.Literal(UBits(3, 4)), fb.Literal(UBits(3, 4)), + fb.Literal(UBits(3, 4)), fb.Literal(UBits(3, 4))}); + fb.UDiv(dividend, divisor); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + + ScopedRecordIr sri(p.get()); + ScopedVerifyEquivalence sve(f); + ASSERT_THAT(Run(f), IsOkAndHolds(true)); + EXPECT_THAT( + f->return_value(), + m::PrioritySelect(m::Concat(m::Eq(m::Select(), m::Literal(3))), + {m::UDiv(m::Param("dividend"), m::Literal(3))}, + /*default_value=*/ + m::Shrl(m::Param("dividend"), m::Encode(m::Select())))); +} + +TEST_F(ValueSetSimplificationPassTest, SDivByFewConstantsFallback) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue dividend = fb.Param("dividend", p->GetBitsType(4)); + BValue s = fb.Param("s", p->GetBitsType(2)); + BValue divisor = + fb.Select(s, {fb.Literal(SBits(3, 4)), fb.Literal(SBits(5, 4)), + fb.Literal(SBits(3, 4)), fb.Literal(SBits(5, 4))}); + fb.SDiv(dividend, divisor); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + + ScopedRecordIr sri(p.get()); + ScopedVerifyEquivalence sve(f); + ASSERT_THAT(Run(f), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), + m::PrioritySelect(m::Concat(m::Eq(m::Select(), m::Literal(5)), + m::Eq(m::Select(), m::Literal(3))), + {m::SDiv(m::Param("dividend"), m::Literal(3)), + m::SDiv(m::Param("dividend"), m::Literal(5))}, + /*default_value=*/m::Literal(0))); +} + 
+TEST_F(ValueSetSimplificationPassTest, SDivByMoreConstantsHybridCase) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue dividend = fb.Param("dividend", p->GetBitsType(4)); + BValue s = fb.Param("s", p->GetBitsType(3)); + BValue divisor = + fb.Select(s, {fb.Literal(SBits(1, 4)), fb.Literal(SBits(2, 4)), + fb.Literal(SBits(4, 4)), fb.Literal(SBits(3, 4)), + fb.Literal(SBits(3, 4)), fb.Literal(SBits(3, 4)), + fb.Literal(SBits(3, 4)), fb.Literal(SBits(3, 4))}); + fb.SDiv(dividend, divisor); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + + ScopedRecordIr sri(p.get()); + ScopedVerifyEquivalence sve(f); + ASSERT_THAT(Run(f), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), + m::PrioritySelect(m::Concat(m::Eq(m::Select(), m::Literal(4)), + m::Eq(m::Select(), m::Literal(3)), + m::Eq(m::Select(), m::Literal(2)), + m::Eq(m::Select(), m::Literal(1))), + {m::SDiv(m::Param("dividend"), m::Literal(1)), + m::SDiv(m::Param("dividend"), m::Literal(2)), + m::SDiv(m::Param("dividend"), m::Literal(3)), + m::SDiv(m::Param("dividend"), m::Literal(4))}, + /*default_value=*/m::Literal(0))); +} + +TEST_F(ValueSetSimplificationPassTest, SDivByZeroFallback) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue dividend = fb.Param("dividend", p->GetBitsType(4)); + BValue s = fb.Param("s", p->GetBitsType(2)); + BValue divisor = + fb.Select(s, {fb.Literal(SBits(3, 4)), fb.Literal(SBits(5, 4)), + fb.Literal(SBits(0, 4)), fb.Literal(SBits(0, 4))}); + fb.SDiv(dividend, divisor); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + ASSERT_THAT(Run(f), IsOkAndHolds(true)); + EXPECT_THAT( + f->return_value(), + m::PrioritySelect(m::Concat(m::Eq(m::Select(), m::Literal(5)), + m::Eq(m::Select(), m::Literal(3))), + {m::SDiv(m::Param("dividend"), m::Literal(3)), + m::SDiv(m::Param("dividend"), m::Literal(5))}, + /*default_value=*/ + m::Select(m::BitSlice(m::Param("dividend"), 3, 1), + {m::Literal(Bits::MaxSigned(4)), + 
m::Literal(Bits::MinSigned(4))}))); +} + +TEST_F(ValueSetSimplificationPassTest, + UDivByManyConstantsFallbackUnprofitable) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue dividend = fb.Param("dividend", p->GetBitsType(4)); + BValue s = fb.Param("s", p->GetBitsType(2)); + BValue divisor = + fb.Select(s, {fb.Literal(UBits(3, 4)), fb.Literal(UBits(5, 4)), + fb.Literal(UBits(7, 4)), fb.Literal(UBits(3, 4))}); + fb.UDiv(dividend, divisor); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + ASSERT_THAT(Run(f), IsOkAndHolds(false)); +} + +class UnprofitableFakeAreaEstimator : public AreaEstimator { + public: + UnprofitableFakeAreaEstimator() : AreaEstimator("unprofitable_fake") {} + absl::StatusOr GetOperationAreaInSquareMicrons( + Node* node) const override { + if (node->op() == Op::kUMul || node->op() == Op::kSMul) { + return 1000.0; + } + if (node->op() == Op::kUDiv || node->op() == Op::kSDiv) { + return 10.0; + } + return 1.0; + } + absl::StatusOr GetOneBitRegisterAreaInSquareMicrons() const override { + return 1.0; + } +}; + +TEST_F(ValueSetSimplificationPassTest, UDivUnprofitableWithAreaEstimator) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue dividend = fb.Param("dividend", p->GetBitsType(4)); + BValue s = fb.Param("s", p->GetBitsType(2)); + BValue divisor = + fb.Select(s, {fb.Literal(UBits(3, 4)), fb.Literal(UBits(5, 4)), + fb.Literal(UBits(3, 4)), fb.Literal(UBits(5, 4))}); + fb.UDiv(dividend, divisor); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + + UnprofitableFakeAreaEstimator fake_ae; + OptimizationPassOptions options; + options.area_estimator = &fake_ae; + PassResults results; + OptimizationContext context; + ASSERT_THAT(ValueSetSimplificationPass().RunOnFunctionBase(f, options, + &results, context), + IsOkAndHolds(false)); +} + +} // namespace +} // namespace xls