Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions xls/contrib/xlscc/generate_fsm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -837,8 +837,12 @@ NewFSMGenerator::GenerateNewFSMInvocation(
xls_state_element = pb.StateElement(
state_element.name, xls::ZeroOfType(state_element.type), body_loc);
} else {
xls::StateRead* state_read = pb.proc()->GetStateReadByStateElement(
state_element.existing_state_element);
absl::Span<xls::StateRead* const> reads =
pb.proc()->GetStateReadsByStateElement(
state_element.existing_state_element);
XLSCC_CHECK(!reads.empty(), body_loc);
XLSCC_CHECK_LE(reads.size(), 1, body_loc);
xls::StateRead* state_read = reads.front();
xls_state_element = TrackedBValue(state_read, &pb);
}

Expand Down
27 changes: 21 additions & 6 deletions xls/contrib/xlscc/translate_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "clang/include/clang/AST/Attr.h"
#include "clang/include/clang/AST/Attrs.inc"
#include "clang/include/clang/AST/Decl.h"
Expand Down Expand Up @@ -258,8 +259,11 @@ absl::StatusOr<TrackedBValue> ComposeStaticValueInput(
if (!generate_new_fsm || !TypeIsDecomposable(xls_type)) {
xls::StateElement* state_element = state_element_for_static.at(
DeclLeaf{.decl = namedecl, .leaf_index = -1});
return TrackedBValue(pb.proc()->GetStateReadByStateElement(state_element),
&pb);
absl::Span<xls::StateRead* const> reads =
pb.proc()->GetStateReadsByStateElement(state_element);
CHECK(!reads.empty());
CHECK_LE(reads.size(), 1);
return TrackedBValue(reads.front(), &pb);
}
absl::InlinedVector<xls::Type*, 1> decomposed_types =
DecomposeTupleTypes(xls_type);
Expand All @@ -268,7 +272,11 @@ absl::StatusOr<TrackedBValue> ComposeStaticValueInput(
for (int64_t i = 0; i < decomposed_types.size(); ++i) {
xls::StateElement* decomposed_element = state_element_for_static.at(
DeclLeaf{.decl = namedecl, .leaf_index = i});
nodes.push_back(pb.proc()->GetStateReadByStateElement(decomposed_element));
absl::Span<xls::StateRead* const> reads =
pb.proc()->GetStateReadsByStateElement(decomposed_element);
CHECK(!reads.empty());
CHECK_LE(reads.size(), 1);
nodes.push_back(reads.front());
}

XLS_ASSIGN_OR_RETURN(xls::Node * node,
Expand Down Expand Up @@ -662,8 +670,11 @@ absl::StatusOr<xls::Proc*> Translator::GenerateIR_Block(
next_state_value.value = TrackedBValue(decomposed_next_val, &pb);
} else {
XLSCC_CHECK_EQ(decomposed_elems.size(), 1, body_loc);
xls::StateRead* state_read =
pb.proc()->GetStateReadByStateElement(decomposed_elem);
absl::Span<xls::StateRead* const> reads =
pb.proc()->GetStateReadsByStateElement(decomposed_elem);
XLSCC_CHECK(!reads.empty(), body_loc);
XLSCC_CHECK_LE(reads.size(), 1, body_loc);
xls::StateRead* state_read = reads.front();
TrackedBValue prev_val(state_read, &pb);
next_state_value.value =
pb.And(prev_val,
Expand Down Expand Up @@ -2268,7 +2279,11 @@ absl::StatusOr<xls::Proc*> Translator::BuildWithNextStateValueMap(
return absl::InternalError(
absl::StrFormat("No next values for state element %s", elem->name()));
}
xls::StateRead* state_read = pb.proc()->GetStateReadByStateElement(elem);
absl::Span<xls::StateRead* const> reads =
pb.proc()->GetStateReadsByStateElement(elem);
XLSCC_CHECK(!reads.empty(), loc);
XLSCC_CHECK_LE(reads.size(), 1, loc);
xls::StateRead* state_read = reads.front();
TrackedBValue read_bval(state_read, &pb);
if (values_for_elem == 1) {
const NextStateValue& next_state_value =
Expand Down
8 changes: 6 additions & 2 deletions xls/contrib/xlscc/translate_loops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "clang/include/clang/AST/Attr.h"
#include "clang/include/clang/AST/Decl.h"
#include "clang/include/clang/AST/Expr.h"
Expand Down Expand Up @@ -1374,8 +1375,11 @@ Translator::GenerateIR_PipelinedLoopContents(
} else {
xls::StateElement* state_elem =
prepared.state_element_for_variable.at(DeclLeaf{.decl = decl});
state_reads_by_decl[decl] =
TrackedBValue(pb.proc()->GetStateReadByStateElement(state_elem), &pb);
absl::Span<xls::StateRead* const> reads =
pb.proc()->GetStateReadsByStateElement(state_elem);
XLSCC_CHECK(!reads.empty(), loc);
XLSCC_CHECK_LE(reads.size(), 1, loc);
state_reads_by_decl[decl] = TrackedBValue(reads.front(), &pb);
}
}

Expand Down
6 changes: 5 additions & 1 deletion xls/contrib/xlscc/unit_tests/unit_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,11 @@ XlsccTestBase::GetStatesByIONodeForFSMProc(std::string_view func_name) {

CHECK_EQ(found_proc_with_fsm, nullptr);
found_proc_with_fsm = proc.get();
fsm_state_read = proc->GetStateReadByStateElement(state_element);
absl::Span<xls::StateRead* const> reads =
proc->GetStateReadsByStateElement(state_element);
CHECK(!reads.empty());
CHECK_LE(reads.size(), 1);
fsm_state_read = reads.front();

CHECK_NE(found_proc_with_fsm, nullptr);
CHECK_NE(fsm_state_read, nullptr);
Expand Down
1 change: 1 addition & 0 deletions xls/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,7 @@ cc_test(
"@com_google_absl//absl/base",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@googletest//:gtest",
],
)
Expand Down
8 changes: 5 additions & 3 deletions xls/ir/ir_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1895,9 +1895,11 @@ absl::StatusOr<Parser::BodyResult> Parser::ParseBody(
Proc * source_proc,
ParseProc(package, /*outer_attributes=*/{}, &source));
for (StateElement* element : source_proc->StateElements()) {
name_to_value->emplace(
element->name(),
bb->SourceNode(source_proc->GetStateReadByStateElement(element)));
absl::Span<StateRead* const> reads =
source_proc->GetStateReadsByStateElement(element);
XLS_RET_CHECK_EQ(reads.size(), 1);
name_to_value->emplace(element->name(),
bb->SourceNode(reads.front()));
}
} else {
return absl::InvalidArgumentError(absl::StrFormat(
Expand Down
18 changes: 13 additions & 5 deletions xls/ir/ir_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/substitute.h"
#include "absl/types/span.h"
#include "xls/common/source_location.h"
#include "xls/common/status/matchers.h"
#include "xls/ir/bits.h"
Expand Down Expand Up @@ -607,7 +608,9 @@ proc foo( x: bits[32], y: (), z: bits[32], init={42, (), 123}) {
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, package->GetProc("foo"));
EXPECT_EQ(proc->GetStateElementCount(), 3);
XLS_ASSERT_OK_AND_ASSIGN(StateElement * x, proc->GetStateElementByName("x"));
EXPECT_THAT(proc->GetStateReadByStateElement(x)->predicate(), std::nullopt);
absl::Span<StateRead* const> reads = proc->GetStateReadsByStateElement(x);
ASSERT_EQ(reads.size(), 1);
EXPECT_THAT(reads.front()->predicate(), std::nullopt);
}

TEST(IrParserTest, ProcWithPredicatedStateRead) {
Expand All @@ -626,17 +629,22 @@ proc foo( x: bits[32], y: bits[1], z: bits[32], init={42, 1, 123}) {
EXPECT_EQ(proc->GetStateElementCount(), 3);

XLS_ASSERT_OK_AND_ASSIGN(StateElement * x, proc->GetStateElementByName("x"));
std::optional<Node*> x_predicate =
proc->GetStateReadByStateElement(x)->predicate();
absl::Span<StateRead* const> reads_x = proc->GetStateReadsByStateElement(x);
ASSERT_EQ(reads_x.size(), 1);
std::optional<Node*> x_predicate = reads_x.front()->predicate();
ASSERT_TRUE(x_predicate.has_value());
ASSERT_EQ((*x_predicate)->op(), Op::kStateRead);
EXPECT_EQ((*x_predicate)->As<StateRead>()->state_element()->name(), "y");

XLS_ASSERT_OK_AND_ASSIGN(StateElement * y, proc->GetStateElementByName("y"));
ASSERT_FALSE(proc->GetStateReadByStateElement(y)->predicate().has_value());
absl::Span<StateRead* const> reads_y = proc->GetStateReadsByStateElement(y);
ASSERT_EQ(reads_y.size(), 1);
ASSERT_FALSE(reads_y.front()->predicate().has_value());

XLS_ASSERT_OK_AND_ASSIGN(StateElement * z, proc->GetStateElementByName("z"));
ASSERT_FALSE(proc->GetStateReadByStateElement(z)->predicate().has_value());
absl::Span<StateRead* const> reads_z = proc->GetStateReadsByStateElement(z);
ASSERT_EQ(reads_z.size(), 1);
ASSERT_FALSE(reads_z.front()->predicate().has_value());
}

TEST(IrParserTest, ParseSendReceiveChannel) {
Expand Down
4 changes: 2 additions & 2 deletions xls/ir/node_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ TEST_F(NodeUtilTest, ChannelNodes) {

EXPECT_THAT(GetChannelUsedByNode(rcv.node()), IsOkAndHolds(ch0));
EXPECT_THAT(GetChannelUsedByNode(send.node()), IsOkAndHolds(ch1));
EXPECT_THAT(GetChannelUsedByNode(proc->GetStateRead(0)),
EXPECT_THAT(GetChannelUsedByNode(proc->GetStateReads(0).front()),
StatusIs(absl::StatusCode::kNotFound,
HasSubstr("No channel associated with node")));
}
Expand Down Expand Up @@ -435,7 +435,7 @@ TEST_F(NodeUtilTest, ReplaceTupleIndicesWorksWithToken) {
// works, we'd need to make an after_all and add the receive's output token to
// it after calling ReplaceTupleElementsWith().
XLS_EXPECT_OK(ReplaceTupleElementsWith(
receive_node, {{0, proc->GetStateRead(0)}, {1, lit0}}));
receive_node, {{0, proc->GetStateReads(0).front()}, {1, lit0}}));

ExpectIr(proc->DumpIr(), TestName());
}
Expand Down
68 changes: 43 additions & 25 deletions xls/ir/proc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,15 @@ absl::Status Proc::RemoveStateElement(int64_t index) {
StateElement* old_state_element = GetStateElement(index);
auto old_state_read_it = state_reads_.find(old_state_element);
XLS_RET_CHECK(old_state_read_it != state_reads_.end());
if (!old_state_read_it->second->users().empty()) {
return absl::InvalidArgumentError(absl::StrFormat(
"Cannot remove state element %d of proc %s, existing "
"state read %s has uses",
index, name(), old_state_read_it->second->GetNameView()));
for (StateRead* read : old_state_read_it->second) {
if (!read->users().empty()) {
return absl::InvalidArgumentError(
absl::StrFormat("Cannot remove state element %d of proc %s, existing "
"state read %s has uses",
index, name(), read->GetNameView()));
}
XLS_RETURN_IF_ERROR(RemoveNode(read));
}
XLS_RETURN_IF_ERROR(RemoveNode(old_state_read_it->second));
// TODO(allight): This should ideally not need to be done manually.
state_reads_.erase(old_state_read_it);

Expand All @@ -232,11 +234,14 @@ absl::Status Proc::RemoveStateElement(int64_t index) {
absl::Status Proc::RemoveAllStateElements() {
// TODO(allight): This relies on side tables being valid. For now just let it
// go.
for (const auto& [elem, read] : state_reads_) {
if (read != nullptr) {
XLS_RETURN_IF_ERROR(RemoveNode(read))
<< "Cannot remove " << elem->ToString() << " of proc " << name()
<< " because read '" << read->ToString() << "' could not be removed.";
for (const auto& [elem, reads] : state_reads_) {
for (StateRead* read : reads) {
if (read != nullptr) {
XLS_RETURN_IF_ERROR(RemoveNode(read))
<< "Cannot remove " << elem->ToString() << " of proc " << name()
<< " because read '" << read->ToString()
<< "' could not be removed.";
}
}
XLS_RETURN_IF_ERROR(state_name_uniquer_.ReleaseIdentifier(elem->name()))
<< "Cannot release name of " << elem->ToString();
Expand Down Expand Up @@ -278,7 +283,7 @@ absl::StatusOr<StateRead*> Proc::InsertStateElement(
MakeNodeWithName<StateRead>(
loc, state_element, read_predicate,
/*label=*/std::nullopt, state_element->name()));
state_reads_[state_element] = state_read;
state_reads_[state_element].push_back(state_read);

if (next_state.has_value()) {
if (!ValueConformsToType(init_value, next_state.value()->GetType())) {
Expand Down Expand Up @@ -351,14 +356,13 @@ absl::StatusOr<Proc*> Proc::Clone(
return mapping.at(orig);
};
for (StateElement* state_element : StateElements()) {
StateRead* state_read = state_reads_.at(state_element);
XLS_ASSIGN_OR_RETURN(
StateRead * cloned_state_read,
cloned_proc->AppendStateElement(
remap_name(state_name_remapping, state_element->name()),
state_element->initial_value(), state_read->predicate(),
/*next_state=*/std::nullopt));
original_to_clone[state_read] = cloned_state_read;
XLS_RETURN_IF_ERROR(
cloned_proc
->InsertUnreadStateElement(
cloned_proc->GetStateElementCount(),
remap_name(state_name_remapping, state_element->name()),
state_element->initial_value())
.status());
}
if (is_new_style_proc()) {
absl::flat_hash_map<ChannelInterface*, ChannelInterface*> channel_map;
Expand Down Expand Up @@ -445,7 +449,23 @@ absl::StatusOr<Proc*> Proc::Clone(

switch (node->op()) {
case Op::kStateRead: {
continue;
StateRead* src = node->As<StateRead>();
StateElement* src_elem = src->state_element();
XLS_ASSIGN_OR_RETURN(int64_t idx, GetStateElementIndex(src_elem));
StateElement* cloned_elem = cloned_proc->GetStateElement(idx);

std::optional<Node*> cloned_predicate;
if (src->predicate().has_value()) {
cloned_predicate = original_to_clone.at(src->predicate().value());
}

XLS_ASSIGN_OR_RETURN(StateRead * cloned_state_read,
cloned_proc->MakeNodeWithName<StateRead>(
src->loc(), cloned_elem, cloned_predicate,
/*label=*/std::nullopt, cloned_elem->name()));
cloned_proc->state_reads_[cloned_elem].push_back(cloned_state_read);
original_to_clone[node] = cloned_state_read;
break;
}
case Op::kReceive: {
Receive* src = node->As<Receive>();
Expand Down Expand Up @@ -1000,10 +1020,8 @@ absl::Status Proc::InternalRebuildSideTables() {
state_reads_.clear();
for (Node* n : nodes()) {
if (n->Is<StateRead>()) {
XLS_RET_CHECK(!state_reads_.contains(n->As<StateRead>()->state_element()))
<< "Duplicate state element read: "
<< n->As<StateRead>()->state_element();
state_reads_[n->As<StateRead>()->state_element()] = n->As<StateRead>();
state_reads_[n->As<StateRead>()->state_element()].push_back(
n->As<StateRead>());
} else if (n->Is<Next>()) {
next_values_.push_back(n->As<Next>());
next_values_by_state_element_[n->As<Next>()->state_element()].insert(
Expand Down
19 changes: 16 additions & 3 deletions xls/ir/proc.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,23 @@ class Proc : public FunctionBase {
return state_elements_.contains(name);
}

// Remove legacy getters after all downstream passes migrate logic.
StateRead* GetStateRead(int64_t index) const {
return state_reads_.at(GetStateElement(index));
return GetStateReads(index).front();
}

StateRead* GetStateReadByStateElement(StateElement* state_element) const {
return GetStateReadsByStateElement(state_element).front();
}

// Get state reads for a state element at the given index.
absl::Span<StateRead* const> GetStateReads(int64_t index) const {
return state_reads_.at(GetStateElement(index));
}

// Get state reads for a state element.
absl::Span<StateRead* const> GetStateReadsByStateElement(
StateElement* state_element) const {
return state_reads_.at(state_element);
}

Expand Down Expand Up @@ -403,8 +416,8 @@ class Proc : public FunctionBase {
absl::flat_hash_map<std::string, std::unique_ptr<StateElement>>
state_elements_;

// Map of the unique StateRead node for each state element.
absl::flat_hash_map<StateElement*, StateRead*> state_reads_;
// Map of StateRead nodes for each state element.
absl::flat_hash_map<StateElement*, std::vector<StateRead*>> state_reads_;

// Vector of state element pointers. Kept in sync with the state_elements_
// map. Enables easy, stable iteration over state elements. With this vector,
Expand Down
Loading
Loading