Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 50 additions & 28 deletions xls/dslx/ir_convert/function_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2655,25 +2655,29 @@ absl::Status FunctionConverter::HandleCoverBuiltin(const Invocation* node,
absl::Status FunctionConverter::HandleBuiltinRead(const Invocation* node) {
XLS_RETURN_IF_ERROR(ValidateProcState("read", node));
XLS_RET_CHECK_EQ(node->args().size(), 1);
Expr* source = node->args()[0];

BValue active = implicit_token_data_->create_control_predicate();
if (state_read_called_.has_value()) {
BValue state_read_called = state_read_map_[source->ToString()];
if (options_.emit_assert) {
// Assert multiple reads don't happen in same activation.
implicit_token_data_->entry_token = function_builder_->Assert(
implicit_token_data_->entry_token,
function_builder_->Or(function_builder_->Not(active),
function_builder_->Not(*state_read_called_)),
function_builder_->Not(state_read_called)),
"State element read after read in same activation.");
state_read_called_ = function_builder_->Or(*state_read_called_, active);
implicit_token_data_->control_tokens.push_back(
implicit_token_data_->entry_token);
}

Expr* source = node->args()[0];
state_read_map_[source->ToString()] =
function_builder_->Or(state_read_called, active);
XLS_RETURN_IF_ERROR(Visit(source));
XLS_ASSIGN_OR_RETURN(BValue state_element, Use(source));
XLS_ASSIGN_OR_RETURN(BValue state_read, Use(source));
if (node->label().has_value()) {
state_read.node()->As<StateRead>()->set_label(*node->label());
}
Def(node, [&](const SourceInfo& loc) {
return function_builder_->Identity(state_element, loc);
return function_builder_->Identity(state_read, loc);
});
return absl::OkStatus();
}
Expand All @@ -2683,34 +2687,36 @@ absl::Status FunctionConverter::HandleBuiltinWrite(const Invocation* node) {
XLS_RET_CHECK_EQ(node->args().size(), 2);
XLS_RETURN_IF_ERROR(Visit(node->args()[1]));
XLS_ASSIGN_OR_RETURN(BValue update_val, Use(node->args()[1]));
Expr* target = node->args()[0];

BValue active = implicit_token_data_->create_control_predicate();
if (state_read_called_.has_value() && state_write_called_.has_value()) {
BValue state_read_called = state_read_map_[target->ToString()];
BValue state_write_called = state_write_map_[target->ToString()];
if (options_.emit_assert) {
// Assert write doesn't happen before a read
implicit_token_data_->entry_token = function_builder_->Assert(
implicit_token_data_->entry_token,
function_builder_->Or(function_builder_->Not(active),
*state_read_called_),
state_read_called),
"State element written before read in same activation.");

// Assert multiple writes don't happen in same activation
implicit_token_data_->entry_token = function_builder_->Assert(
implicit_token_data_->entry_token,
function_builder_->Or(function_builder_->Not(active),
function_builder_->Not(*state_write_called_)),
function_builder_->Not(state_write_called)),
"State element written after write in same activation.");

state_write_called_ = function_builder_->Or(*state_write_called_, active);
implicit_token_data_->control_tokens.push_back(
implicit_token_data_->entry_token);
}

Expr* target = node->args()[0];
state_write_map_[target->ToString()] =
function_builder_->Or(state_write_called, active);
XLS_RETURN_IF_ERROR(Visit(target));
XLS_ASSIGN_OR_RETURN(BValue state_element, Use(target));
XLS_ASSIGN_OR_RETURN(BValue state_read, Use(target));
ProcBuilder* builder_ptr =
dynamic_cast<ProcBuilder*>(function_builder_.get());
builder_ptr->Next(state_element, update_val, active);
builder_ptr->Next(state_read, update_val, active, node->label());
node_to_ir_[node] = function_builder_->Tuple({});
return absl::OkStatus();
}
Expand Down Expand Up @@ -3722,8 +3728,10 @@ absl::Status FunctionConverter::HandleProcNextFunction(
module_->attributes().contains(ModuleAttribute::kExplicitStateAccess);
BValue state = builder->StateElement(state_name, initial_element);
if (explicit_state_access) {
state_read_called_ = builder->Literal(UBits(0, 1));
state_write_called_ = builder->Literal(UBits(0, 1));
for (Param* p : f->params()) {
state_read_map_[p->identifier()] = builder->Literal(UBits(0, 1));
state_write_map_[p->identifier()] = builder->Literal(UBits(0, 1));
}
}
tokens_.push_back(implicit_token);
auto* builder_ptr = builder.get();
Expand All @@ -3733,14 +3741,32 @@ absl::Status FunctionConverter::HandleProcNextFunction(
XLS_RETURN_IF_ERROR(builder_ptr->SetAsTop());
}

// Set the one state element.
// Set the state element(s).
XLS_RET_CHECK(proc_proto_);
PackageInterfaceProto::NamedValue* state_proto =
proc_proto_.value()->add_state();
*state_proto->mutable_name() = state_name;
*state_proto->mutable_type() = state.GetType()->ToProto();
// State elements aren't emitted in an observable way so no need to track sv
// types.
if (explicit_state_access) {
for (int i = 0; i < f->params().size(); ++i) {
Param* p = f->params()[i];
PackageInterfaceProto::Proc::StateValue* state_value =
proc_proto_.value()->add_state_values();
PackageInterfaceProto::NamedValue* state_proto =
state_value->mutable_name();
state_proto->set_name(p->identifier());
XLS_ASSIGN_OR_RETURN(auto type, ResolveTypeToIr(p->type_annotation()));
*state_proto->mutable_type() = type->ToProto();
Value init = f->params().size() > 1 ? initial_element.elements()[i]
: initial_element;
BValue p_state = builder_ptr->StateElement(p->identifier(), init);
SetNodeToIr(p->name_def(), p_state);
}
} else {
PackageInterfaceProto::NamedValue* state_proto =
proc_proto_.value()->add_state();
*state_proto->mutable_name() = state_name;
*state_proto->mutable_type() = state.GetType()->ToProto();
// Bind the recurrent state element.
XLS_RET_CHECK_EQ(f->params().size(), 1);
SetNodeToIr(f->params()[0]->name_def(), state);
}

// We make an implicit token in case any downstream functions need it; if it's
// unused, it'll be optimized out later.
Expand All @@ -3751,10 +3777,6 @@ absl::Status FunctionConverter::HandleProcNextFunction(
[this]() { return implicit_token_data_->activated; },
};

// Bind the recurrent state element.
XLS_RET_CHECK_EQ(f->params().size(), 1);
SetNodeToIr(f->params()[0]->name_def(), state);

proc_id_ = proc_id;

for (ParametricBinding* parametric_binding : f->parametric_bindings()) {
Expand Down
6 changes: 4 additions & 2 deletions xls/dslx/ir_convert/function_converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,8 @@ class FunctionConverter {
absl::Status HandleBuiltinEnumerate(const Invocation* node);
absl::Status HandleBuiltinGate(const Invocation* node);
absl::Status HandleBuiltinJoin(const Invocation* node);
absl::Status HandleBuiltinLabeledRead(const Invocation* node);
absl::Status HandleBuiltinLabeledWrite(const Invocation* node);
absl::Status HandleBuiltinOneHot(const Invocation* node);
absl::Status HandleBuiltinOneHotSel(const Invocation* node);
absl::Status HandleBuiltinOrReduce(const Invocation* node);
Expand Down Expand Up @@ -666,8 +668,8 @@ class FunctionConverter {
// A predicate for whether any code path leading to a state read or write has
// been taken each proc activation. This gets updated to accumulate terms
// every time we emit a state read/write.
std::optional<BValue> state_read_called_;
std::optional<BValue> state_write_called_;
absl::flat_hash_map<std::string, BValue> state_read_map_;
absl::flat_hash_map<std::string, BValue> state_write_map_;
};

} // namespace xls::dslx
Expand Down
135 changes: 123 additions & 12 deletions xls/dslx/ir_convert/ir_converter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8342,7 +8342,7 @@ proc p {
}

TEST_F(IrConverterTest, ExplicitStateAccessMultipleReads) {
constexpr std::string_view program = R"(#![feature(explicit_state_access)]
constexpr std::string_view kModule = R"(#![feature(explicit_state_access)]
proc main {

config() { () }
Expand All @@ -8356,12 +8356,12 @@ proc main {
}
)";
XLS_ASSERT_OK_AND_ASSIGN(std::string converted,
ConvertModuleForTest(program));
ConvertModuleForTest(kModule));
ExpectIr(converted);
}

TEST_F(IrConverterTest, ExplicitStateAccessMultipleWrites) {
constexpr std::string_view program = R"(#![feature(explicit_state_access)]
constexpr std::string_view kModule = R"(#![feature(explicit_state_access)]
proc main {
config() { () }

Expand All @@ -8376,12 +8376,12 @@ proc main {
}
)";
XLS_ASSERT_OK_AND_ASSIGN(std::string converted,
ConvertModuleForTest(program));
ConvertModuleForTest(kModule));
ExpectIr(converted);
}

TEST_F(IrConverterTest, ExplicitStateAccessWriteBeforeRead) {
constexpr std::string_view program = R"(#![feature(explicit_state_access)]
constexpr std::string_view kModule = R"(#![feature(explicit_state_access)]
proc main {
config() { () }

Expand All @@ -8393,12 +8393,12 @@ proc main {
}
)";
XLS_ASSERT_OK_AND_ASSIGN(std::string converted,
ConvertModuleForTest(program));
ConvertModuleForTest(kModule));
ExpectIr(converted);
}

TEST_F(IrConverterTest, ExplicitStateAccessConditionalWrite) {
constexpr std::string_view program = R"(#![feature(explicit_state_access)]
constexpr std::string_view kModule = R"(#![feature(explicit_state_access)]
proc main {
config() { () }

Expand All @@ -8415,12 +8415,12 @@ proc main {
}
)";
XLS_ASSERT_OK_AND_ASSIGN(std::string converted,
ConvertModuleForTest(program));
ConvertModuleForTest(kModule));
ExpectIr(converted);
}

TEST_F(IrConverterTest, ExplicitStateAccessMatch) {
constexpr std::string_view program = R"(#![feature(explicit_state_access)]
constexpr std::string_view kModule = R"(#![feature(explicit_state_access)]
proc main {
config() { () }
init { 0 }
Expand All @@ -8442,12 +8442,12 @@ proc main {
}
)";
XLS_ASSERT_OK_AND_ASSIGN(std::string converted,
ConvertModuleForTest(program));
ConvertModuleForTest(kModule));
ExpectIr(converted);
}

TEST_F(IrConverterTest, ExplicitStateAccessMatchMultipleWrites) {
constexpr std::string_view program = R"(#![feature(explicit_state_access)]
constexpr std::string_view kModule = R"(#![feature(explicit_state_access)]
proc main {
config() { () }
init { true }
Expand All @@ -8467,7 +8467,118 @@ proc main {
}
)";
XLS_ASSERT_OK_AND_ASSIGN(std::string converted,
ConvertModuleForTest(program));
ConvertModuleForTest(kModule));
ExpectIr(converted);
}

TEST_F(IrConverterTest, ExplicitStateAccessLabeledReadAndWrite) {
constexpr std::string_view kModule = R"(#![feature(explicit_state_access)]
proc main {
init { 0 }
config() { }
next(state: u32) {
let x = 'main_read:read(state);
let y = x + 1;
'main_write:write(state, y);
}
}
)";
XLS_ASSERT_OK_AND_ASSIGN(std::string converted,
ConvertModuleForTest(kModule));
ExpectIr(converted);
}

TEST_F(IrConverterTest, ExplicitStateAccessReadWithLabeledRead) {
constexpr std::string_view kModule = R"(#![feature(explicit_state_access)]
proc main {
init { 0 }
config() { }
next(state: u32) {
let curr = read(state);
let x = 'main_read:read(state);
let y = x + 1 + curr;
'main_write:write(state, y);
}
}
)";
XLS_ASSERT_OK_AND_ASSIGN(std::string converted,
ConvertModuleForTest(kModule));
ExpectIr(converted);
}

TEST_F(IrConverterTest, ExplicitStateAccessMultipleStates) {
constexpr std::string_view kModule = R"(#![feature(explicit_state_access)]
struct Point {
x: u32,
y: u32,
}
proc main {
init { (Point { x: 0, y: 1 }, (2, 3), 4) }
config() { }
next(state_0: Point, state_1: (u32, u32), state_2: u32) {
let a = read(state_0);
let b = read(state_1);
let c = read(state_2);
let new_a = Point { x: a.x + 1, y: b.1 + c };
let new_b = (b.0 + 1, c + 1);
write(state_0, new_a);
write(state_1, new_b);
write(state_2, c + 2);
}
}
)";
XLS_ASSERT_OK_AND_ASSIGN(std::string converted,
ConvertModuleForTest(kModule));
ExpectIr(converted);
}

TEST_F(IrConverterTest, ExplicitStateAccessMultipleStatesMultipleReads) {
constexpr std::string_view kModule = R"(#![feature(explicit_state_access)]
proc main {
init { (0, 1) }
config() { }
next(state_0: u32, state_1: u32) {
let b_0 = read(state_1);
let b_1 = read(state_1);
}
}
)";
XLS_ASSERT_OK_AND_ASSIGN(std::string converted,
ConvertModuleForTest(kModule));
ExpectIr(converted);
}

TEST_F(IrConverterTest, ExplicitStateAccessMultipleStatesMultipleWrites) {
constexpr std::string_view kModule = R"(#![feature(explicit_state_access)]
proc main {
init { (0, 1) }
config() { }
next(state_0: u32, state_1: u32) {
let a = read(state_0);
let b = read(state_1);
write(state_1, a + b);
write(state_1, a + 1);
}
}
)";
XLS_ASSERT_OK_AND_ASSIGN(std::string converted,
ConvertModuleForTest(kModule));
ExpectIr(converted);
}

TEST_F(IrConverterTest, ExplicitStateAccessMultipleStatesWriteBeforeRead) {
constexpr std::string_view kModule = R"(#![feature(explicit_state_access)]
proc main {
init { (0, 1) }
config() { }
next(state_0: u32, state_1: u32) {
let a = read(state_0);
write(state_1, a);
}
}
)";
XLS_ASSERT_OK_AND_ASSIGN(std::string converted,
ConvertModuleForTest(kModule));
ExpectIr(converted);
}

Expand Down
Loading
Loading