From 363f3a27c50ef6474ccda9ef325f3ce2051597da Mon Sep 17 00:00:00 2001 From: Nelson Liang Date: Wed, 18 Mar 2026 13:23:46 -0700 Subject: [PATCH] [Explicit State Access] Allow next_values to use state_element instead of state_read. PiperOrigin-RevId: 885757168 --- xls/ir/function.cc | 2 +- xls/ir/function_base.cc | 48 ++++++++++++++++++++++++++++++++--------- xls/ir/function_base.h | 33 +++++++--------------------- xls/ir/node.cc | 10 +++++++-- xls/ir/nodes.cc | 13 ++++++----- xls/ir/nodes.h | 23 +++++++------------- xls/ir/proc.cc | 18 ++++++++++------ xls/ir/proc_test.cc | 39 +++++++++++++++++++++++++++++++++ xls/ir/verify_node.cc | 41 ++++++++++++++++++++++------------- 9 files changed, 148 insertions(+), 79 deletions(-) diff --git a/xls/ir/function.cc b/xls/ir/function.cc index 56784e15be..657093bc0a 100644 --- a/xls/ir/function.cc +++ b/xls/ir/function.cc @@ -329,7 +329,7 @@ absl::Status Function::InternalRebuildSideTables() { // only held in the side table. We can still check for correctness at least. // TODO(allight): We should ideally be able to do this. XLS_RET_CHECK(next_values_.empty()); - XLS_RET_CHECK(next_values_by_state_read_.empty()); + XLS_RET_CHECK(next_values_by_state_element_.empty()); for (Param* p : params_) { XLS_RET_CHECK(p->function_base() == this) diff --git a/xls/ir/function_base.cc b/xls/ir/function_base.cc index 67805781be..f341b4289b 100644 --- a/xls/ir/function_base.cc +++ b/xls/ir/function_base.cc @@ -29,6 +29,7 @@ #include "absl/algorithm/container.h" #include "absl/base/casts.h" +#include "absl/base/no_destructor.h" #include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" @@ -245,7 +246,8 @@ void FunctionBase::TakeOwnershipOfNode(std::unique_ptr&& node) { FunctionBase* old_owner = node->function_base(); if (node->Is()) { - old_owner->next_values_by_state_read_.erase(node->As()); + old_owner->next_values_by_state_element_.erase( + node->As()->state_element()); } old_owner->node_iterators_.erase(node.get()); @@ -442,6 +444,34 @@ absl::StatusOr FunctionBase::GetNode( absl::StrFormat("GetNode(%s) failed.", standard_node_name)); } +const absl::btree_set& FunctionBase::next_values( + StateRead* state_read) const { + return next_values(state_read->state_element()); +} + +const absl::btree_set& FunctionBase::next_values( + StateElement* state_element) const { + if (!next_values_by_state_element_.contains(state_element)) { + static const absl::NoDestructor< + absl::btree_set> + kEmptySet; + // This should be pretty rare. Basically this should only happen in the + // short time before the non-updated state element is replaced. Just + // returning what is actually there is nicer than crashing however. Do + // check that this is not just some sort of weird corruption however. + CHECK(absl::c_none_of(nodes(), + [state_element](Node* n) { + return n->Is() && + n->As()->state_element() == + state_element; + })) + << "Invalid side table for next values. Missing " << state_element + << " in " << this; + return *kEmptySet; + } + return next_values_by_state_element_.at(state_element); +} + absl::Status FunctionBase::RemoveNode(Node* node) { XLS_RET_CHECK(node->users().empty()) << node->GetName(); XLS_RET_CHECK(!HasImplicitUse(node)) << node->GetName(); @@ -462,13 +492,12 @@ absl::Status FunctionBase::RemoveNode(Node* node) { params_.end()); } if (node->Is()) { - next_values_by_state_read_.erase(node->As()); + next_values_by_state_element_.erase(node->As()->state_element()); } if (node->Is()) { Next* next = node->As(); - if (next->state_read()->Is()) { // Could've been replaced. - StateRead* state_read = next->state_read()->As(); - next_values_by_state_read_.at(state_read).erase(next); + if (next_values_by_state_element_.contains(next->state_element())) { + next_values_by_state_element_.at(next->state_element()).erase(next); } std::erase(next_values_, next); } @@ -559,13 +588,12 @@ Node* FunctionBase::AddNodeInternal(std::unique_ptr node) { params_.push_back(node->As()); } if (node->Is()) { - next_values_by_state_read_[node->As()]; + next_values_by_state_element_[node->As()->state_element()]; } if (node->Is()) { Next* next = node->As(); - StateRead* state_read = next->state_read()->As(); - next_values_.push_back(node->As()); - next_values_by_state_read_[state_read].insert(next); + next_values_.push_back(next); + next_values_by_state_element_[next->state_element()].insert(next); } Node* ptr = node.get(); node_iterators_[ptr] = nodes_.insert(nodes_.end(), std::move(node)); @@ -683,7 +711,7 @@ absl::Status FunctionBase::RebuildSideTables() { // TODO(allight): The fact that there is so much crap in the function_base // itself is a problem. Having next's and params' in the function base doesn't // make a ton of sense. - // NB Because of above the next-values/next_values_by_state_read_ and params + // NB Because of above the next-values/next_values_by_state_element_ and // lists are updated in proc and function respectively. // NB We assume that node_iterators_ never gets invalidated. XLS_RETURN_IF_ERROR(InternalRebuildSideTables()); diff --git a/xls/ir/function_base.h b/xls/ir/function_base.h index 76f3fc391e..ab19e86fd6 100644 --- a/xls/ir/function_base.h +++ b/xls/ir/function_base.h @@ -26,8 +26,6 @@ #include #include -#include "absl/algorithm/container.h" -#include "absl/base/no_destructor.h" #include "absl/base/optimization.h" #include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" @@ -249,26 +247,10 @@ class FunctionBase { absl::Span next_values() const { return next_values_; } const absl::btree_set& next_values( - StateRead* state_read) const { - if (!next_values_by_state_read_.contains(state_read)) { - // This should be pretty rare. Basically this should only happen in the - // short time before the non-updated state element is replaced. Just - // returning what is actually there is nicer than crashing however. Do - // check that this is not just some sort of weird corruption however. - static const absl::NoDestructor< - absl::btree_set> - kEmptySet; - CHECK(absl::c_none_of(nodes(), - [state_read](Node* n) { - return n->Is() && - n->As()->state_read() == state_read; - })) - << "Invalid side table for next values. Missing " << state_read - << " in " << this; - return *kEmptySet; - } - return next_values_by_state_read_.at(state_read); - } + StateRead* state_read) const; + + const absl::btree_set& next_values( + StateElement* state_element) const; // Moves the given param to the given index in the parameter list. absl::Status MoveParamToIndex(Param* param, int64_t index); @@ -497,7 +479,7 @@ class FunctionBase { // together. return !n->Is() && !n->Is(); }); - other.next_values_by_state_read_.clear(); + other.next_values_by_state_element_.clear(); other.next_values_.clear(); } @@ -553,8 +535,9 @@ class FunctionBase { std::vector params_; std::vector next_values_; - absl::flat_hash_map> - next_values_by_state_read_; + absl::flat_hash_map> + next_values_by_state_element_; NameUniquer node_name_uniquer_ = NameUniquer(/*separator=*/"__", GetIrReservedWords()); diff --git a/xls/ir/node.cc b/xls/ir/node.cc index ecb21e8575..f5658077b4 100644 --- a/xls/ir/node.cc +++ b/xls/ir/node.cc @@ -640,8 +640,14 @@ std::string Node::ToStringInternal(bool include_operand_types) const { } case Op::kNext: { const Next* next = As(); - args = {absl::StrFormat("param=%s", next->state_read()->GetName()), - absl::StrFormat("value=%s", next->value()->GetName())}; + if (next->has_state_read_operand()) { + args = {absl::StrFormat("param=%s", next->state_read()->GetName()), + absl::StrFormat("value=%s", next->value()->GetName())}; + } else { + args = { + absl::StrFormat("state_element=%s", next->state_element()->name()), + absl::StrFormat("value=%s", next->value()->GetName())}; + } std::optional predicate = next->predicate(); if (predicate.has_value()) { args.push_back( diff --git a/xls/ir/nodes.cc b/xls/ir/nodes.cc index 763b4bfe03..641969fb1a 100755 --- a/xls/ir/nodes.cc +++ b/xls/ir/nodes.cc @@ -988,6 +988,7 @@ Next::Next(const SourceInfo& loc, StateElement* state_element, Node* value, : Node(Op::kNext, function->package()->GetTupleType({}), loc, name, function), state_element_(state_element), + has_state_read_operand_(false), has_predicate_(predicate.has_value()), predicate_operand_index_(1), label_(std::move(label)) { @@ -1002,6 +1003,8 @@ Next::Next(const SourceInfo& loc, Node* state_read, Node* value, std::string_view name, FunctionBase* function) : Node(Op::kNext, function->package()->GetTupleType({}), loc, name, function), + state_element_(state_read->As()->state_element()), + has_state_read_operand_(true), has_predicate_(predicate.has_value()), predicate_operand_index_(2), label_(std::move(label)) { @@ -1466,17 +1469,17 @@ absl::StatusOr StateRead::CloneInNewFunction( absl::StatusOr Next::CloneInNewFunction( absl::Span new_operands, FunctionBase* new_function) const { - if (state_element_ != nullptr) { + if (has_state_read_operand_) { return new_function->MakeNodeWithName( - loc(), state_element_, new_operands[0], - new_operands.size() > 1 ? std::make_optional(new_operands[1]) + loc(), new_operands[0], new_operands[1], + new_operands.size() > 2 ? std::make_optional(new_operands[2]) : std::nullopt, label(), GetNameView()); } // TODO(meheff): Choose an appropriate name for the cloned node. return new_function->MakeNodeWithName( - loc(), new_operands[0], new_operands[1], - new_operands.size() > 2 ? std::make_optional(new_operands[2]) + loc(), state_element_, new_operands[0], + new_operands.size() > 1 ? std::make_optional(new_operands[1]) : std::nullopt, label(), GetNameView()); } diff --git a/xls/ir/nodes.h b/xls/ir/nodes.h index 7cef2f05bd..b29b28550d 100755 --- a/xls/ir/nodes.h +++ b/xls/ir/nodes.h @@ -831,17 +831,12 @@ class Next final : public Node { absl::Span new_operands, FunctionBase* new_function) const final; Node* state_read() const { - CHECK(state_element_ == nullptr) << "StateElement is set"; + CHECK(has_state_read_operand_) + << "Next node does not have a state_read operand"; return operand(0); } - Node* value() const { - if (state_element_ != nullptr) { - return operand(0); - } else { - return operand(1); - } - } + Node* value() const { return operand(has_state_read_operand_ ? 1 : 0); } const std::optional& label() const { return label_; } @@ -885,15 +880,13 @@ class Next final : public Node { bool IsDefinitelyEqualTo(const Node* other) const final; - StateElement* state_element() const { - if (state_element_ != nullptr) { - return state_element_; - } - return state_read()->As()->state_element(); - } + bool has_state_read_operand() const { return has_state_read_operand_; } + + StateElement* state_element() const { return state_element_; } private: - StateElement* state_element_ = nullptr; + StateElement* state_element_; + bool has_state_read_operand_; bool has_predicate_; const int64_t predicate_operand_index_; std::optional label_; diff --git a/xls/ir/proc.cc b/xls/ir/proc.cc index 1d2e41c785..7d689e44c0 100644 --- a/xls/ir/proc.cc +++ b/xls/ir/proc.cc @@ -972,14 +972,20 @@ absl::StatusOr Proc::TransformStateElement( nt.old_next->GetName())); to_replace.push_back({nt.old_next, nxt}); // Identity-ify the old next. - XLS_RETURN_IF_ERROR(nt.old_next->ReplaceOperandNumber( - Next::kValueOperand, nt.old_next->state_read())); + if (nt.old_next->has_state_read_operand()) { + XLS_RETURN_IF_ERROR(nt.old_next->ReplaceOperandNumber( + Next::kValueOperand, nt.old_next->state_read())); + } else { + XLS_RETURN_IF_ERROR( + nt.old_next->ReplaceOperandNumber(/*operand_no=*/0, old_state_read)); + } } for (const auto& [old_n, new_n] : to_replace) { XLS_RETURN_IF_ERROR(old_n->ReplaceUsesWith( new_n, [&](Node* n) { - if (n->Is() && n->As()->state_read() == old_n) { + if (n->Is() && n->As()->has_state_read_operand() && + n->As()->state_read() == old_n) { return false; } return true; @@ -993,7 +999,7 @@ absl::Status Proc::InternalRebuildSideTables() { XLS_RET_CHECK(params_.empty()); // Why is next-values in base but not elements? next_values_.clear(); - next_values_by_state_read_.clear(); + next_values_by_state_element_.clear(); state_reads_.clear(); for (Node* n : nodes()) { if (n->Is()) { @@ -1003,8 +1009,8 @@ absl::Status Proc::InternalRebuildSideTables() { state_reads_[n->As()->state_element()] = n->As(); } else if (n->Is()) { next_values_.push_back(n->As()); - next_values_by_state_read_[n->As()->state_read()->As()] - .insert(n->As()); + next_values_by_state_element_[n->As()->state_element()].insert( + n->As()); } } // TODO(allight): We should make it so we can recover channel/proc-inst things diff --git a/xls/ir/proc_test.cc b/xls/ir/proc_test.cc index 8e382302c5..949421622e 100644 --- a/xls/ir/proc_test.cc +++ b/xls/ir/proc_test.cc @@ -205,6 +205,45 @@ TEST_F(ProcTest, StatelessProc) { EXPECT_EQ(proc->DumpIr(), "proc p() {\n}\n"); } +TEST_F(ProcTest, NextValuesByStateElement) { + auto p = CreatePackage(); + ProcBuilder pb("p", p.get()); + BValue state = pb.StateElement("st", Value(UBits(42, 32))); + BValue add = pb.Add(pb.Literal(UBits(1, 32)), state); + XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build({add})); + + StateElement* st_elem = proc->GetStateElement(0); + StateRead* st_read = proc->GetStateRead(st_elem); + + EXPECT_THAT(proc->next_values(st_elem), + ElementsAre(m::Next(m::StateRead("st"), m::Add()))); + EXPECT_THAT(proc->next_values(st_read), + ElementsAre(m::Next(m::StateRead("st"), m::Add()))); + + // Add another next value for the same state element using the StateElement + // constructor. + XLS_ASSERT_OK_AND_ASSIGN( + Node * literal_10, + proc->MakeNode(SourceInfo(), Value(UBits(10, 32)))); + XLS_ASSERT_OK_AND_ASSIGN( + Next * next2, + proc->MakeNode(SourceInfo(), st_elem, literal_10, + /*predicate=*/std::nullopt, /*label=*/std::nullopt)); + + EXPECT_THAT( + proc->next_values(st_elem), + UnorderedElementsAre(m::Next(m::StateRead("st"), m::Add()), next2)); + EXPECT_THAT( + proc->next_values(st_read), + UnorderedElementsAre(m::Next(m::StateRead("st"), m::Add()), next2)); + + XLS_ASSERT_OK(proc->RemoveNode(next2)); + EXPECT_THAT(proc->next_values(st_elem), + ElementsAre(m::Next(m::StateRead("st"), m::Add()))); + EXPECT_THAT(proc->next_values(st_read), + ElementsAre(m::Next(m::StateRead("st"), m::Add()))); +} + TEST_F(ProcTest, RemoveStateThatStillHasUse) { // Don't call CreatePackage which creates a VerifiedPackage because we // intentionally create a malformed proc. diff --git a/xls/ir/verify_node.cc b/xls/ir/verify_node.cc index 7931e869fe..3ec98894a6 100644 --- a/xls/ir/verify_node.cc +++ b/xls/ir/verify_node.cc @@ -978,29 +978,40 @@ class NodeChecker : public DfsVisitor { } absl::Status HandleNext(Next* next) override { - XLS_RETURN_IF_ERROR(ExpectOperandCountRange(next, 2, 3)); - if (!next->state_read()->Is()) { - return absl::InternalError( - absl::StrFormat("Next node %s expects a state read for param; is: %v", - next->GetName(), *next->state_read())); + if (next->has_state_read_operand()) { + XLS_RETURN_IF_ERROR(ExpectOperandCountRange(next, 2, 3)); + if (!next->state_read()->Is()) { + return absl::InternalError(absl::StrFormat( + "Next node %s expects a state read for param; is: %v", + next->GetName(), *next->state_read())); + } + } else { + XLS_RETURN_IF_ERROR(ExpectOperandCountRange(next, 1, 2)); } + if (next->predicate().has_value()) { - XLS_RETURN_IF_ERROR(ExpectOperandHasBitsType(next, /*operand_no=*/2, + XLS_ASSIGN_OR_RETURN(int64_t pred_idx, next->predicate_operand_number()); + XLS_RETURN_IF_ERROR(ExpectOperandHasBitsType(next, pred_idx, /*expected_bit_count=*/1)); } + if (!next->function_base()->HasEffectiveProc()) { return absl::InternalError(absl::StrFormat( - "Next node %s (for param %s) is not in a proc", next->GetName(), - next->state_read()->As()->state_element()->name())); + "Next node %s (for state element %s) is not in a proc", + next->GetName(), next->state_element()->name())); } + Proc* proc = next->function_base()->GetEffectiveProcOrDie(); - XLS_ASSIGN_OR_RETURN( - int64_t index, - proc->GetStateElementIndex( - next->state_read()->As()->state_element())); - XLS_RETURN_IF_ERROR(ExpectOperandHasType(next, /*operand_no=*/0, - proc->GetStateElementType(index))); - return ExpectOperandHasType(next, /*operand_no=*/1, // value is operand 1 + XLS_ASSIGN_OR_RETURN(int64_t index, + proc->GetStateElementIndex(next->state_element())); + + if (next->has_state_read_operand()) { + XLS_RETURN_IF_ERROR(ExpectOperandHasType( + next, /*operand_no=*/0, proc->GetStateElementType(index))); + } + + int64_t value_idx = next->has_state_read_operand() ? 1 : 0; + return ExpectOperandHasType(next, value_idx, proc->GetStateElementType(index)); }