Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xls/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ absl::Status Function::InternalRebuildSideTables() {
// only held in the side table. We can still check for correctness at least.
// TODO(allight): We should ideally be able to do this.
XLS_RET_CHECK(next_values_.empty());
XLS_RET_CHECK(next_values_by_state_read_.empty());
XLS_RET_CHECK(next_values_by_state_element_.empty());

for (Param* p : params_) {
XLS_RET_CHECK(p->function_base() == this)
Expand Down
48 changes: 38 additions & 10 deletions xls/ir/function_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include "absl/algorithm/container.h"
#include "absl/base/casts.h"
#include "absl/base/no_destructor.h"
#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
Expand Down Expand Up @@ -245,7 +246,8 @@ void FunctionBase::TakeOwnershipOfNode(std::unique_ptr<Node>&& node) {
FunctionBase* old_owner = node->function_base();

if (node->Is<StateRead>()) {
old_owner->next_values_by_state_read_.erase(node->As<StateRead>());
old_owner->next_values_by_state_element_.erase(
node->As<StateRead>()->state_element());
}

old_owner->node_iterators_.erase(node.get());
Expand Down Expand Up @@ -442,6 +444,34 @@ absl::StatusOr<Node*> FunctionBase::GetNode(
absl::StrFormat("GetNode(%s) failed.", standard_node_name));
}

const absl::btree_set<Next*, Node::NodeIdLessThan>& FunctionBase::next_values(
StateRead* state_read) const {
return next_values(state_read->state_element());
}

const absl::btree_set<Next*, Node::NodeIdLessThan>& FunctionBase::next_values(
StateElement* state_element) const {
if (!next_values_by_state_element_.contains(state_element)) {
static const absl::NoDestructor<
absl::btree_set<Next*, Node::NodeIdLessThan>>
kEmptySet;
// This should be pretty rare. Basically this should only happen in the
// short time before the non-updated state element is replaced. Just
// returning what is actually there is nicer than crashing however. Do
// check that this is not just some sort of weird corruption however.
CHECK(absl::c_none_of(nodes(),
[state_element](Node* n) {
return n->Is<Next>() &&
n->As<Next>()->state_element() ==
state_element;
}))
<< "Invalid side table for next values. Missing " << state_element
<< " in " << this;
return *kEmptySet;
}
return next_values_by_state_element_.at(state_element);
}

absl::Status FunctionBase::RemoveNode(Node* node) {
XLS_RET_CHECK(node->users().empty()) << node->GetName();
XLS_RET_CHECK(!HasImplicitUse(node)) << node->GetName();
Expand All @@ -462,13 +492,12 @@ absl::Status FunctionBase::RemoveNode(Node* node) {
params_.end());
}
if (node->Is<StateRead>()) {
next_values_by_state_read_.erase(node->As<StateRead>());
next_values_by_state_element_.erase(node->As<StateRead>()->state_element());
}
if (node->Is<Next>()) {
Next* next = node->As<Next>();
if (next->state_read()->Is<StateRead>()) { // Could've been replaced.
StateRead* state_read = next->state_read()->As<StateRead>();
next_values_by_state_read_.at(state_read).erase(next);
if (next_values_by_state_element_.contains(next->state_element())) {
next_values_by_state_element_.at(next->state_element()).erase(next);
}
std::erase(next_values_, next);
}
Expand Down Expand Up @@ -559,13 +588,12 @@ Node* FunctionBase::AddNodeInternal(std::unique_ptr<Node> node) {
params_.push_back(node->As<Param>());
}
if (node->Is<StateRead>()) {
next_values_by_state_read_[node->As<StateRead>()];
next_values_by_state_element_[node->As<StateRead>()->state_element()];
}
if (node->Is<Next>()) {
Next* next = node->As<Next>();
StateRead* state_read = next->state_read()->As<StateRead>();
next_values_.push_back(node->As<Next>());
next_values_by_state_read_[state_read].insert(next);
next_values_.push_back(next);
next_values_by_state_element_[next->state_element()].insert(next);
}
Node* ptr = node.get();
node_iterators_[ptr] = nodes_.insert(nodes_.end(), std::move(node));
Expand Down Expand Up @@ -683,7 +711,7 @@ absl::Status FunctionBase::RebuildSideTables() {
// TODO(allight): The fact that there is so much crap in the function_base
// itself is a problem. Having next's and params' in the function base doesn't
// make a ton of sense.
// NB Because of above the next-values/next_values_by_state_read_ and params
// NB Because of above the next-values/next_values_by_state_element_ and
// lists are updated in proc and function respectively.
// NB We assume that node_iterators_ never gets invalidated.
XLS_RETURN_IF_ERROR(InternalRebuildSideTables());
Expand Down
33 changes: 8 additions & 25 deletions xls/ir/function_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/no_destructor.h"
#include "absl/base/optimization.h"
#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_map.h"
Expand Down Expand Up @@ -249,26 +247,10 @@ class FunctionBase {
absl::Span<Next* const> next_values() const { return next_values_; }

const absl::btree_set<Next*, Node::NodeIdLessThan>& next_values(
StateRead* state_read) const {
if (!next_values_by_state_read_.contains(state_read)) {
// This should be pretty rare. Basically this should only happen in the
// short time before the non-updated state element is replaced. Just
// returning what is actually there is nicer than crashing however. Do
// check that this is not just some sort of weird corruption however.
static const absl::NoDestructor<
absl::btree_set<Next*, Node::NodeIdLessThan>>
kEmptySet;
CHECK(absl::c_none_of(nodes(),
[state_read](Node* n) {
return n->Is<Next>() &&
n->As<Next>()->state_read() == state_read;
}))
<< "Invalid side table for next values. Missing " << state_read
<< " in " << this;
return *kEmptySet;
}
return next_values_by_state_read_.at(state_read);
}
StateRead* state_read) const;

const absl::btree_set<Next*, Node::NodeIdLessThan>& next_values(
StateElement* state_element) const;

// Moves the given param to the given index in the parameter list.
absl::Status MoveParamToIndex(Param* param, int64_t index);
Expand Down Expand Up @@ -497,7 +479,7 @@ class FunctionBase {
// together.
return !n->Is<Param>() && !n->Is<StateRead>();
});
other.next_values_by_state_read_.clear();
other.next_values_by_state_element_.clear();
other.next_values_.clear();
}

Expand Down Expand Up @@ -553,8 +535,9 @@ class FunctionBase {

std::vector<Param*> params_;
std::vector<Next*> next_values_;
absl::flat_hash_map<StateRead*, absl::btree_set<Next*, Node::NodeIdLessThan>>
next_values_by_state_read_;
absl::flat_hash_map<StateElement*,
absl::btree_set<Next*, Node::NodeIdLessThan>>
next_values_by_state_element_;

NameUniquer node_name_uniquer_ =
NameUniquer(/*separator=*/"__", GetIrReservedWords());
Expand Down
10 changes: 8 additions & 2 deletions xls/ir/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -640,8 +640,14 @@ std::string Node::ToStringInternal(bool include_operand_types) const {
}
case Op::kNext: {
const Next* next = As<Next>();
args = {absl::StrFormat("param=%s", next->state_read()->GetName()),
absl::StrFormat("value=%s", next->value()->GetName())};
if (next->has_state_read_operand()) {
args = {absl::StrFormat("param=%s", next->state_read()->GetName()),
absl::StrFormat("value=%s", next->value()->GetName())};
} else {
args = {
absl::StrFormat("state_element=%s", next->state_element()->name()),
absl::StrFormat("value=%s", next->value()->GetName())};
}
std::optional<Node*> predicate = next->predicate();
if (predicate.has_value()) {
args.push_back(
Expand Down
13 changes: 8 additions & 5 deletions xls/ir/nodes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,7 @@ Next::Next(const SourceInfo& loc, StateElement* state_element, Node* value,
: Node(Op::kNext, function->package()->GetTupleType({}), loc, name,
function),
state_element_(state_element),
has_state_read_operand_(false),
has_predicate_(predicate.has_value()),
predicate_operand_index_(1),
label_(std::move(label)) {
Expand All @@ -1002,6 +1003,8 @@ Next::Next(const SourceInfo& loc, Node* state_read, Node* value,
std::string_view name, FunctionBase* function)
: Node(Op::kNext, function->package()->GetTupleType({}), loc, name,
function),
state_element_(state_read->As<StateRead>()->state_element()),
has_state_read_operand_(true),
has_predicate_(predicate.has_value()),
predicate_operand_index_(2),
label_(std::move(label)) {
Expand Down Expand Up @@ -1466,17 +1469,17 @@ absl::StatusOr<Node*> StateRead::CloneInNewFunction(

absl::StatusOr<Node*> Next::CloneInNewFunction(
absl::Span<Node* const> new_operands, FunctionBase* new_function) const {
if (state_element_ != nullptr) {
if (has_state_read_operand_) {
return new_function->MakeNodeWithName<Next>(
loc(), state_element_, new_operands[0],
new_operands.size() > 1 ? std::make_optional(new_operands[1])
loc(), new_operands[0], new_operands[1],
new_operands.size() > 2 ? std::make_optional(new_operands[2])
: std::nullopt,
label(), GetNameView());
}
// TODO(meheff): Choose an appropriate name for the cloned node.
return new_function->MakeNodeWithName<Next>(
loc(), new_operands[0], new_operands[1],
new_operands.size() > 2 ? std::make_optional(new_operands[2])
loc(), state_element_, new_operands[0],
new_operands.size() > 1 ? std::make_optional(new_operands[1])
: std::nullopt,
label(), GetNameView());
}
Expand Down
23 changes: 8 additions & 15 deletions xls/ir/nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -831,17 +831,12 @@ class Next final : public Node {
absl::Span<Node* const> new_operands,
FunctionBase* new_function) const final;
Node* state_read() const {
CHECK(state_element_ == nullptr) << "StateElement is set";
CHECK(has_state_read_operand_)
<< "Next node does not have a state_read operand";
return operand(0);
}

Node* value() const {
if (state_element_ != nullptr) {
return operand(0);
} else {
return operand(1);
}
}
Node* value() const { return operand(has_state_read_operand_ ? 1 : 0); }

const std::optional<std::string>& label() const { return label_; }

Expand Down Expand Up @@ -885,15 +880,13 @@ class Next final : public Node {

bool IsDefinitelyEqualTo(const Node* other) const final;

StateElement* state_element() const {
if (state_element_ != nullptr) {
return state_element_;
}
return state_read()->As<StateRead>()->state_element();
}
bool has_state_read_operand() const { return has_state_read_operand_; }

StateElement* state_element() const { return state_element_; }

private:
StateElement* state_element_ = nullptr;
StateElement* state_element_;
bool has_state_read_operand_;
bool has_predicate_;
const int64_t predicate_operand_index_;
std::optional<std::string> label_;
Expand Down
18 changes: 12 additions & 6 deletions xls/ir/proc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -972,14 +972,20 @@ absl::StatusOr<StateRead*> Proc::TransformStateElement(
nt.old_next->GetName()));
to_replace.push_back({nt.old_next, nxt});
// Identity-ify the old next.
XLS_RETURN_IF_ERROR(nt.old_next->ReplaceOperandNumber(
Next::kValueOperand, nt.old_next->state_read()));
if (nt.old_next->has_state_read_operand()) {
XLS_RETURN_IF_ERROR(nt.old_next->ReplaceOperandNumber(
Next::kValueOperand, nt.old_next->state_read()));
} else {
XLS_RETURN_IF_ERROR(
nt.old_next->ReplaceOperandNumber(/*operand_no=*/0, old_state_read));
}
}
for (const auto& [old_n, new_n] : to_replace) {
XLS_RETURN_IF_ERROR(old_n->ReplaceUsesWith(
new_n,
[&](Node* n) {
if (n->Is<Next>() && n->As<Next>()->state_read() == old_n) {
if (n->Is<Next>() && n->As<Next>()->has_state_read_operand() &&
n->As<Next>()->state_read() == old_n) {
return false;
}
return true;
Expand All @@ -993,7 +999,7 @@ absl::Status Proc::InternalRebuildSideTables() {
XLS_RET_CHECK(params_.empty());
// Why is next-values in base but not elements?
next_values_.clear();
next_values_by_state_read_.clear();
next_values_by_state_element_.clear();
state_reads_.clear();
for (Node* n : nodes()) {
if (n->Is<StateRead>()) {
Expand All @@ -1003,8 +1009,8 @@ absl::Status Proc::InternalRebuildSideTables() {
state_reads_[n->As<StateRead>()->state_element()] = n->As<StateRead>();
} else if (n->Is<Next>()) {
next_values_.push_back(n->As<Next>());
next_values_by_state_read_[n->As<Next>()->state_read()->As<StateRead>()]
.insert(n->As<Next>());
next_values_by_state_element_[n->As<Next>()->state_element()].insert(
n->As<Next>());
}
}
// TODO(allight): We should make it so we can recover channel/proc-inst things
Expand Down
39 changes: 39 additions & 0 deletions xls/ir/proc_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,45 @@ TEST_F(ProcTest, StatelessProc) {
EXPECT_EQ(proc->DumpIr(), "proc p() {\n}\n");
}

TEST_F(ProcTest, NextValuesByStateElement) {
auto p = CreatePackage();
ProcBuilder pb("p", p.get());
BValue state = pb.StateElement("st", Value(UBits(42, 32)));
BValue add = pb.Add(pb.Literal(UBits(1, 32)), state);
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build({add}));

StateElement* st_elem = proc->GetStateElement(0);
StateRead* st_read = proc->GetStateRead(st_elem);

EXPECT_THAT(proc->next_values(st_elem),
ElementsAre(m::Next(m::StateRead("st"), m::Add())));
EXPECT_THAT(proc->next_values(st_read),
ElementsAre(m::Next(m::StateRead("st"), m::Add())));

// Add another next value for the same state element using the StateElement
// constructor.
XLS_ASSERT_OK_AND_ASSIGN(
Node * literal_10,
proc->MakeNode<Literal>(SourceInfo(), Value(UBits(10, 32))));
XLS_ASSERT_OK_AND_ASSIGN(
Next * next2,
proc->MakeNode<Next>(SourceInfo(), st_elem, literal_10,
/*predicate=*/std::nullopt, /*label=*/std::nullopt));

EXPECT_THAT(
proc->next_values(st_elem),
UnorderedElementsAre(m::Next(m::StateRead("st"), m::Add()), next2));
EXPECT_THAT(
proc->next_values(st_read),
UnorderedElementsAre(m::Next(m::StateRead("st"), m::Add()), next2));

XLS_ASSERT_OK(proc->RemoveNode(next2));
EXPECT_THAT(proc->next_values(st_elem),
ElementsAre(m::Next(m::StateRead("st"), m::Add())));
EXPECT_THAT(proc->next_values(st_read),
ElementsAre(m::Next(m::StateRead("st"), m::Add())));
}

TEST_F(ProcTest, RemoveStateThatStillHasUse) {
// Don't call CreatePackage which creates a VerifiedPackage because we
// intentionally create a malformed proc.
Expand Down
Loading
Loading