Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions xls/dslx/fmt/ast_fmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,12 @@ DocRef Fmt(const ChannelTypeAnnotation& n, Comments& comments,
return ConcatNGroup(arena, pieces);
}

DocRef Fmt(const TypeVariableTypeAnnotation& n, Comments& comments,
DocArena& arena) {
std::vector<DocRef> pieces = {Fmt(*n.type_variable(), comments, arena)};
return ConcatNGroup(arena, pieces);
}

DocRef Fmt(const TypeAnnotation& n, Comments& comments, DocArena& arena) {
if (auto* t = dynamic_cast<const BuiltinTypeAnnotation*>(&n)) {
return Fmt(*t, comments, arena);
Expand All @@ -433,6 +439,12 @@ DocRef Fmt(const TypeAnnotation& n, Comments& comments, DocArena& arena) {
if (auto* t = dynamic_cast<const ChannelTypeAnnotation*>(&n)) {
return Fmt(*t, comments, arena);
}
if (auto* t = dynamic_cast<const TypeVariableTypeAnnotation*>(&n)) {
return Fmt(*t, comments, arena);
}
if (dynamic_cast<const GenericTypeAnnotation*>(&n)) {
return arena.Make(Keyword::kType);
}
if (dynamic_cast<const SelfTypeAnnotation*>(&n)) {
return arena.Make(Keyword::kSelfType);
}
Expand Down
3 changes: 2 additions & 1 deletion xls/dslx/frontend/parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4077,7 +4077,8 @@ TEST_F(ParserTest, ParseMapWithLambdaNoParamAnnotation) {

TEST_F(ParserTest, ParseMapWithLambdaNoParamAnnotationMultipleParams) {
RoundTrip(
R"(const ARR = map(enumerate(range(0, u16:5)), |i, j| { 2 * i * j });)");
R"(import std;
const ARR = map(std::enumerate(range(0, u16:5)), |i, j| { 2 * i * j });)");
}

TEST_F(ParserTest, ParseLambdaWithNoBrackets) {
Expand Down
3 changes: 2 additions & 1 deletion xls/dslx/ir_convert/ir_converter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -899,8 +899,9 @@ fn main(input: u8[2]) -> u8[2] {

TEST_F(IrConverterTest, ArrayEnumerate) {
constexpr std::string_view program = R"(
import std;
fn main(array: u8[4]) -> (u32, u8)[4] {
enumerate(array)
std::enumerate(array)
}
)";
XLS_ASSERT_OK_AND_ASSIGN(std::string converted,
Expand Down
36 changes: 21 additions & 15 deletions xls/dslx/ir_convert/testdata/ir_converter_test_ArrayEnumerate.ir
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
package test_module

file_number 0 "test_module.x"
file_number 0 "xls/dslx/stdlib/std.x"
file_number 1 "test_module.x"

top fn __test_module__main(array: bits[8][4] id=1) -> (bits[32], bits[8])[4] {
literal.2: bits[32] = literal(value=0, id=2)
literal.5: bits[32] = literal(value=1, id=5)
literal.8: bits[32] = literal(value=2, id=8)
literal.11: bits[32] = literal(value=3, id=11)
array_index.3: bits[8] = array_index(array, indices=[literal.2], id=3)
array_index.6: bits[8] = array_index(array, indices=[literal.5], id=6)
array_index.9: bits[8] = array_index(array, indices=[literal.8], id=9)
array_index.12: bits[8] = array_index(array, indices=[literal.11], id=12)
tuple.4: (bits[32], bits[8]) = tuple(literal.2, array_index.3, id=4)
tuple.7: (bits[32], bits[8]) = tuple(literal.5, array_index.6, id=7)
tuple.10: (bits[32], bits[8]) = tuple(literal.8, array_index.9, id=10)
tuple.13: (bits[32], bits[8]) = tuple(literal.11, array_index.12, id=13)
ret array.14: (bits[32], bits[8])[4] = array(tuple.4, tuple.7, tuple.10, tuple.13, id=14)
fn ____std__enumerate__4_u8_counted_for_0_body(i: bits[32] id=7, result: (bits[32], bits[8])[4] id=10, x: bits[8][4] id=11) -> (bits[32], bits[8])[4] {
literal.8: bits[32] = literal(value=0, id=8)
add.9: bits[32] = add(i, literal.8, id=9)
array_index.12: bits[8] = array_index(x, indices=[add.9], id=12)
tuple.13: (bits[32], bits[8]) = tuple(add.9, array_index.12, id=13)
ret array_update.14: (bits[32], bits[8])[4] = array_update(result, tuple.13, indices=[add.9], id=14)
}

fn __std__enumerate__4_u8(x: bits[8][4] id=1) -> (bits[32], bits[8])[4] {
literal.3: bits[32] = literal(value=0, id=3)
literal.4: bits[8] = literal(value=0, id=4)
tuple.5: (bits[32], bits[8]) = tuple(literal.3, literal.4, id=5)
array.6: (bits[32], bits[8])[4] = array(tuple.5, tuple.5, tuple.5, tuple.5, id=6)
N: bits[32] = literal(value=4, id=2)
ret counted_for.15: (bits[32], bits[8])[4] = counted_for(array.6, trip_count=4, stride=1, body=____std__enumerate__4_u8_counted_for_0_body, invariant_args=[x], id=15)
}

top fn __test_module__main(array: bits[8][4] id=16) -> (bits[32], bits[8])[4] {
ret invoke.17: (bits[32], bits[8])[4] = invoke(array, to_apply=__std__enumerate__4_u8, id=17)
}
37 changes: 37 additions & 0 deletions xls/dslx/stdlib/std.x
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,43 @@

// DSLX standard library routines.

#![feature(generics)]

pub fn enumerate<T: type, N: u32>(x: T[N]) -> (u32, T)[N] {
for (i, result) in 0..N {
update(result, i, (i, x[i]))
}([(u32:0, zero!<T>()), ...])
}

#[test]
fn emumerate_test() {
let array = [1, 2, 4, 8];
let enumerated = enumerate(array);
assert_eq(enumerated[0], (0, 1));
assert_eq(enumerated[1], (1, 2));
assert_eq(enumerated[2], (2, 4));
assert_eq(enumerated[3], (3, 8));
}

#[test]
fn enumerate_type_test() {
type RamData = uN[8];
const DATA = [RamData:1, 2, 4, 8];
let enumerated = enumerate(DATA);
assert_eq(enumerated[0], (0, 1));
assert_eq(enumerated[1], (1, 2));
assert_eq(enumerated[2], (2, 4));
assert_eq(enumerated[3], (3, 8));
}

#[test]
fn enumerate_tuple_test() {
let x = [(true, 3), (false, 2)];
let enumerated = enumerate(x);
assert_eq(enumerated[0], (0, (true, 3)));
assert_eq(enumerated[1], (1, (false, 2)));
}

pub fn sizeof<S: bool, N: u32>(x: xN[S][N]) -> u32 { N }

#[test]
Expand Down
3 changes: 2 additions & 1 deletion xls/dslx/type_system/typecheck_module_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2067,9 +2067,10 @@ fn f() -> u32[3] {

TEST_F(TypecheckV2Test, EnumerateBuiltin) {
XLS_EXPECT_OK(Typecheck(R"(
import std;
type MyTup = (u32, u2);
fn f(x: u2[7]) -> MyTup[7] {
enumerate(x)
std::enumerate(x)
}
)"));
}
Expand Down
4 changes: 3 additions & 1 deletion xls/examples/dslx_intro/prefix_scan_equality.x
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@

// Prefix scans an array of 8 32-bit values and produces a running count of
// duplicate values in the run.
import std;

fn prefix_scan_eq(x: u32[8]) -> u3[8] {
let (_, _, result) =
for ((i, elem), (prior, count, result)): ((u32, u32), (u32, u3, u3[8]))
in enumerate(x) {
in std::enumerate(x) {
let (to_place, new_count): (u3, u3) = match (i == u32:0, prior == elem) {
// The first iteration always places 0 and propagates seen count of 1.
(true, _) => (u3:0, u3:1),
Expand Down
8 changes: 4 additions & 4 deletions xls/modules/rle/rle_dec.x
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ proc RunLengthDecoderTransactionTest {
];
let tok = for ((counter, stimulus), tok):
((u32, (TestSymbol, TestCount)) , token)
in enumerate(TransactionTestStimuli) {
in std::enumerate(TransactionTestStimuli) {
let last = counter == (array_size(TransactionTestStimuli) - u32:1);
let data_in = TestDecInData{
symbol: stimulus.0,
Expand All @@ -191,7 +191,7 @@ proc RunLengthDecoderTransactionTest {
];
let tok = for ((counter, symbol), tok):
((u32, TestSymbol) , token)
in enumerate(TransationTestOutputs) {
in std::enumerate(TransationTestOutputs) {
let last = counter == (array_size(TransationTestOutputs) - u32:1);
let data_out = TestDecOutData{
symbol: symbol,
Expand Down Expand Up @@ -243,7 +243,7 @@ proc RunLengthDecoderLastAfterLastTest {
];
let tok = for ((counter, stimulus), tok):
((u32, TestDecInData) , token)
in enumerate(LastAfterLastTestStimuli) {
in std::enumerate(LastAfterLastTestStimuli) {
let tok = send(tok, dec_input_s, stimulus);
trace_fmt!("Sent {} stimuli, symbol: 0x{:x}, count:{}, last: {}",
counter + u32:1, stimulus.symbol, stimulus.count, stimulus.last);
Expand All @@ -255,7 +255,7 @@ proc RunLengthDecoderLastAfterLastTest {
];
let tok = for ((counter, output), tok):
((u32, TestDecOutData) , token)
in enumerate(LastAfterLastTestOutputs) {
in std::enumerate(LastAfterLastTestOutputs) {
let (tok, dec_output) = recv(tok, dec_output_r);
trace_fmt!(
"Received {} transactions, symbol: 0x{:x}, last: {}",
Expand Down
16 changes: 8 additions & 8 deletions xls/modules/rle/rle_enc.x
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ proc RunLengthEncoderCountSymbolTest {
];
let tok = for ((counter, symbol), tok):
((u32, CountSymbolTestStimulus) , token)
in enumerate(CountSymbolTestTestStimuli) {
in std::enumerate(CountSymbolTestTestStimuli) {
let last = counter == (array_size(CountSymbolTestTestStimuli) - u32:1);
let stimulus = CountSymbolTestEncInData{symbol: symbol, last: last};
let tok = send(tok, enc_input_s, stimulus);
Expand All @@ -210,7 +210,7 @@ proc RunLengthEncoderCountSymbolTest {
];
let tok = for ((counter, (symbol, count)), tok):
((u32, (CountSymbolTestSymbol, CountSymbolTestCount)) , token)
in enumerate(CountSymbolTestTestOutput) {
in std::enumerate(CountSymbolTestTestOutput) {
let last = counter == (array_size(CountSymbolTestTestOutput) - u32:1);
let expected = CountSymbolTestEncOutData{
symbol: symbol, count: count, last: last};
Expand Down Expand Up @@ -264,7 +264,7 @@ proc RunLengthEncoderOverflowTest {
];
let tok = for ((counter, symbol), tok):
((u32, OverflowStimulus) , token)
in enumerate(OverflowTestStimuli) {
in std::enumerate(OverflowTestStimuli) {
let last = counter == (
array_size(OverflowTestStimuli) - u32:1);
let stimulus = OverflowEncInData{symbol: symbol, last: last};
Expand All @@ -284,7 +284,7 @@ proc RunLengthEncoderOverflowTest {
];
let tok = for ((counter, (symbol, count)), tok):
((u32, (OverflowSymbol, OverflowCount)) , token)
in enumerate(OverflowTestOutput) {
in std::enumerate(OverflowTestOutput) {
let last = counter == (array_size(OverflowTestOutput) - u32:1);
let expected = OverflowEncOutData{
symbol: symbol, count: count, last: last};
Expand Down Expand Up @@ -334,7 +334,7 @@ proc RunLengthEncoderLastAfterLastTest {
];
let tok = for ((counter, stimuli), tok):
((u32, LastAfterLastStimulus) , token)
in enumerate(LastAfterLastTestStimuli) {
in std::enumerate(LastAfterLastTestStimuli) {
let tok = send(tok, enc_input_s, stimuli);
trace_fmt!("Sent {} transactions, symbol: 0x{:x}, last: {}",
counter, stimuli.symbol, stimuli.last);
Expand All @@ -352,7 +352,7 @@ proc RunLengthEncoderLastAfterLastTest {
];
let tok = for ((counter, expected), tok):
((u32, LastAfterLastOutput) , token)
in enumerate(LastAfterLastTestOutput) {
in std::enumerate(LastAfterLastTestOutput) {
let (tok, enc_output) = recv(tok, enc_output_r);
trace_fmt!(
"Received {} pairs, symbol: 0x{:x}, count: {}, last: {}",
Expand Down Expand Up @@ -403,7 +403,7 @@ proc RunLengthEncoderOverflowWithLastTest {
];
let tok = for ((counter, symbol), tok):
((u32, OverflowWithLastStimulus) , token)
in enumerate(OverflowWithLastTestStimuli) {
in std::enumerate(OverflowWithLastTestStimuli) {
let last = counter == (
array_size(OverflowWithLastTestStimuli) - u32:1);
let stimulus = OverflowWithLastEncInData{symbol: symbol, last: last};
Expand All @@ -419,7 +419,7 @@ proc RunLengthEncoderOverflowWithLastTest {
];
let tok = for ((counter, (symbol, count)), tok):
((u32, (OverflowWithLastSymbol, OverflowWithLastCount)) , token)
in enumerate(OverflowWithLastTestOutput) {
in std::enumerate(OverflowWithLastTestOutput) {
let last = counter == (array_size(OverflowWithLastTestOutput) - u32:1);
let expected = OverflowWithLastEncOutData{
symbol: symbol, count: count, last: last};
Expand Down
2 changes: 1 addition & 1 deletion xls/modules/zstd/axi_csr_accessor.x
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ proc AxiCsrAccessorTest {

next (state: ()) {
// test writing via AXI
let tok = for ((i, test_data), tok): ((u32, TestData), token) in enumerate(TEST_DATA) {
let tok = for ((i, test_data), tok): ((u32, TestData), token) in std::enumerate(TEST_DATA) {
// write CSR via AXI
let axi_aw = TestAxiAw {
id: i as uN[TEST_ID_W],
Expand Down
2 changes: 1 addition & 1 deletion xls/modules/zstd/comp_block_dec.x
Original file line number Diff line number Diff line change
Expand Up @@ -1608,7 +1608,7 @@ proc CompressBlockDecoderTest {
let (input_length, input, output_length, output) = COMP_BLOCK_DEC_TESTCASES[test_i];

trace_fmt!("Loading testcase {}", test_i);
let tok = for ((i, input_data), tok): ((u32, u64), token) in enumerate(input) {
let tok = for ((i, input_data), tok): ((u32, u64), token) in std::enumerate(input) {
let req = TestcaseRamWrReq {
addr: i as uN[TEST_CASE_RAM_ADDR_W],
data: input_data as uN[TEST_CASE_RAM_DATA_W],
Expand Down
4 changes: 2 additions & 2 deletions xls/modules/zstd/comp_lookup_dec.x
Original file line number Diff line number Diff line change
Expand Up @@ -1799,7 +1799,7 @@ proc CompLookupDecoderTest {
let (input, output, resp_ok) = COMP_LOOKUP_DECODER_TESTCASES[test_i];

trace_fmt!("Loading testcase {:x}", test_i);
let tok = for ((i, input_data), tok): ((u32, u64), token) in enumerate(input) {
let tok = for ((i, input_data), tok): ((u32, u64), token) in std::enumerate(input) {
let req = TestcaseRamWrReq {
addr: i as uN[TEST_CASE_RAM_ADDR_WIDTH],
data: input_data as uN[TEST_CASE_RAM_DATA_WIDTH],
Expand All @@ -1818,7 +1818,7 @@ proc CompLookupDecoderTest {
let (tok, resp) = recv(tok, resp_r);
assert_eq(resp, resp_ok);

let tok = for ((i, output_data), tok): ((u32, FseTableRecord), token) in enumerate(output) {
let tok = for ((i, output_data), tok): ((u32, FseTableRecord), token) in std::enumerate(output) {
let req = FseRamRdReq {
addr: i as uN[TEST_FSE_RAM_ADDR_WIDTH],
mask: std::unsigned_max_value<TEST_FSE_RAM_NUM_PARTITIONS>(),
Expand Down
4 changes: 2 additions & 2 deletions xls/modules/zstd/csr_config.x
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ proc CsrConfig_test {
let expected_values = zero!<uN[TEST_DATA_W][TEST_REGS_N]>();

// Test Writes through external interface
let (tok, expected_values) = for ((i, test_data), (tok, expected_values)): ((u32, TestData), (token, uN[TEST_DATA_W][TEST_REGS_N])) in enumerate(TEST_DATA) {
let (tok, expected_values) = for ((i, test_data), (tok, expected_values)): ((u32, TestData), (token, uN[TEST_DATA_W][TEST_REGS_N])) in std::enumerate(TEST_DATA) {
// write CSR via external interface
let wr_req = TestCsrWrReq {
csr: test_data.csr,
Expand Down Expand Up @@ -344,7 +344,7 @@ proc CsrConfig_test {
}((join(), expected_values));

// Test writes via internal interface
let (tok, _) = for ((i, test_data), (tok, expected_values)): ((u32, TestData), (token, uN[TEST_DATA_W][TEST_REGS_N])) in enumerate(TEST_DATA) {
let (tok, _) = for ((i, test_data), (tok, expected_values)): ((u32, TestData), (token, uN[TEST_DATA_W][TEST_REGS_N])) in std::enumerate(TEST_DATA) {
// write CSR via request channel
let csr_wr_req = TestCsrWrReq {
csr: test_data.csr,
Expand Down
4 changes: 2 additions & 2 deletions xls/modules/zstd/frame_header_dec.x
Original file line number Diff line number Diff line change
Expand Up @@ -599,13 +599,13 @@ proc FrameHeaderDecoderTest {
];

const ADDR = u16:0x1234;
let tok = for ((_, (test_vec, expected)), tok): ((u32, (u32[TEST_XFERS_FOR_HEADER], FrameHeaderDecoderResp)), token) in enumerate(tests) {
let tok = for ((_, (test_vec, expected)), tok): ((u32, (u32[TEST_XFERS_FOR_HEADER], FrameHeaderDecoderResp)), token) in std::enumerate(tests) {
let tok = send(tok, decode_req_s, FrameHeaderDecoderReq { addr: ADDR });
let (tok, recv_data) = recv(tok, reader_req_r);

assert_eq(recv_data, ReaderReq { addr: ADDR, length: MAX_MAGIC_PLUS_HEADER_LEN as u16 });

let tok = for ((j, word), tok): ((u32, u32), token) in enumerate(test_vec) {
let tok = for ((j, word), tok): ((u32, u32), token) in std::enumerate(test_vec) {
let last = j + u32:1 == array_size(test_vec);
send(tok, reader_resp_s, ReaderResp {
status: mem_reader::MemReaderStatus::OKAY,
Expand Down
14 changes: 7 additions & 7 deletions xls/modules/zstd/fse_dec.x
Original file line number Diff line number Diff line change
Expand Up @@ -1352,7 +1352,7 @@ const TEST_EXPECTED_COMMANDS_1 = CommandConstructorData[14]:[
//
// // write OF table
// let tok = for ((i, of_record), tok): ((u32, u32), token) in
// enumerate(TEST_OF_TABLE[state]) {
// std::enumerate(TEST_OF_TABLE[state]) {
// let tok = send(tok, of_fse_wr_req_s, FseRamWrReq {
// addr: i as u8,
// data: of_record,
Expand All @@ -1364,7 +1364,7 @@ const TEST_EXPECTED_COMMANDS_1 = CommandConstructorData[14]:[
//
// // write ML table
// let tok = for ((i, ml_record), tok): ((u32, u32), token) in
// enumerate(TEST_ML_TABLE[state]) {
// std::enumerate(TEST_ML_TABLE[state]) {
// let tok = send(tok, ml_fse_wr_req_s, FseRamWrReq {
// addr: i as u8,
// data: ml_record,
Expand All @@ -1376,7 +1376,7 @@ const TEST_EXPECTED_COMMANDS_1 = CommandConstructorData[14]:[
//
// // write LL table
// let tok = for ((i, ll_record), tok): ((u32, u32), token) in
// enumerate(TEST_LL_TABLE[state]) {
// std::enumerate(TEST_LL_TABLE[state]) {
// let tok = send(tok, ll_fse_wr_req_s, FseRamWrReq {
// addr: i as u8,
// data: ll_record,
Expand All @@ -1395,7 +1395,7 @@ const TEST_EXPECTED_COMMANDS_1 = CommandConstructorData[14]:[
// // block #0
// // send data
// let tok = for ((i, data), tok): ((u32, RefillingSBOutput), token) in
// enumerate(TEST_DATA_0) {
// std::enumerate(TEST_DATA_0) {
// let (tok, buf_ctrl) = recv(tok, rsb_ctrl_r);
// trace_fmt!("Received #{} buf ctrl {:#x}", i + u32:1, buf_ctrl);
// assert_eq(RefillingSBCtrl {length: data.length}, buf_ctrl);
Expand All @@ -1406,7 +1406,7 @@ const TEST_EXPECTED_COMMANDS_1 = CommandConstructorData[14]:[
//
// // recv commands
// let tok = for ((i, expected_cmd), tok): ((u32, CommandConstructorData), token) in
// enumerate(TEST_EXPECTED_COMMANDS_0) {
// std::enumerate(TEST_EXPECTED_COMMANDS_0) {
// let (tok, cmd) = recv(tok, command_r);
// trace_fmt!("Received #{} cmd {:#x}", i + u32:1, cmd);
// assert_eq(expected_cmd, cmd);
Expand All @@ -1420,7 +1420,7 @@ const TEST_EXPECTED_COMMANDS_1 = CommandConstructorData[14]:[
// // block #1
// // send data
// let tok = for ((i, data), tok): ((u32, RefillingSBOutput), token) in
// enumerate(TEST_DATA_1) {
// std::enumerate(TEST_DATA_1) {
// let (tok, buf_ctrl) = recv(tok, rsb_ctrl_r);
// trace_fmt!("Received #{} buf ctrl {:#x}", i + u32:1, buf_ctrl);
// assert_eq(RefillingSBCtrl {length: data.length}, buf_ctrl);
Expand All @@ -1431,7 +1431,7 @@ const TEST_EXPECTED_COMMANDS_1 = CommandConstructorData[14]:[
//
// // recv commands
// let tok = for ((i, expected_cmd), tok): ((u32, CommandConstructorData), token) in
// enumerate(TEST_EXPECTED_COMMANDS_1) {
// std::enumerate(TEST_EXPECTED_COMMANDS_1) {
// let (tok, cmd) = recv(tok, command_r);
// trace_fmt!("Received #{} cmd {:#x}", i + u32:1, cmd);
// assert_eq(expected_cmd, cmd);
Expand Down
Loading
Loading