Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 104 additions & 49 deletions vta/apps/gemm/hardware/chisel/src/main/scala/accel/Compute.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,31 @@ package accel
import chisel3._
import chisel3.util._
import vta.dpi._
import vta.core._
import vta.util.config._
import vta.shell._

class TestConfig extends Config(new CoreConfig ++ new PynqConfig)
/** Compute
*
* Bit Slice GEMM:
*
* 1. Wait for launch to be asserted
* 2. Issue 2 read request for 8-byte value at inp1_baddr address and inp2_baddr address
* 2. Issue 1 read request for 8-bit value at inp1_baddr address (read matrix)
* 3. Wait for the value
* 4. Increment read-address for next value
* 5. Wait for sliced accumulator
* 6. Check if counter (cnt) is equal to length process,
otherwise goto step 2
* 7. Check if reset slice accumulator
* 8. Wait for overall accumulator
* 8. Issue a write request for 8-byte value at out_baddr address
* 5. Repeat until all inp1 data have been read

* 6. Issue 1 read request for 8-bit value at inp2_baddr address (read vector)
* 7. Wait for the value
* 8. Increment read-address for next value
* 9. Repeat until all inp2 data have been read

* 10. Wait for output to be calculated
* 11. Issue a write request for 8-byte value at out_baddr address
* 12. Increment write-address for next value to write
* 13. Check if counter (cntout) is equal to length to assert finish,
otherwise go to step 11
*/
class Compute(implicit config: AccelConfig) extends Module {
val io = IO(new Bundle {
Expand All @@ -47,19 +57,24 @@ class Compute(implicit config: AccelConfig) extends Module {
val ptrs = Input(Vec(config.nPtrs, UInt(config.ptrBits.W)))
val mem = new VTAMemDPIMaster
})
val sIdle :: sReadAReq :: sReadAData :: sReadBReq :: sReadBData :: sWriteReq :: sWriteData :: Nil = Enum(7)
implicit val p: Parameters = new TestConfig
val sIdle :: sReadAReq :: sReadAData :: sReadADone ::sReadBReq :: sReadBData :: sReadBDone :: sInpDone ::sWait:: sWriteReq :: sWriteData :: sWriteDone :: Nil = Enum(12)
val state = RegInit(sIdle)
val shift = io.vals(0)
val length = io.vals(1)
val rstAccum = io.vals(2)
val startDot = io.vals(3)
val cycles = RegInit(0.U(config.regBits.W))
val reg1 = Reg(chiselTypeOf(io.mem.rd.bits))
val reg2 = Reg(chiselTypeOf(io.mem.rd.bits))
val cnt = Reg(UInt(config.regBits.W))
val mvc = Module(new MatrixVectorMultiplication)
val reg1 = Reg(chiselTypeOf(mvc.io.wgt.data.bits))
val reg2 = Reg(chiselTypeOf(mvc.io.inp.data.bits))
val cntwgt = Reg(UInt(config.regBits.W))
val cntinp = Reg(UInt(config.regBits.W))
val cntout = Reg(UInt(config.regBits.W))
val raddr1 = Reg(UInt(config.ptrBits.W))
val raddr2 = Reg(UInt(config.ptrBits.W))
val waddr = Reg(UInt(config.ptrBits.W))
val accum = Module(new Accmulator(size = p(CoreKey).blockOut, accBits = p(CoreKey).accBits))

switch (state) {
is (sIdle) {
Expand All @@ -73,14 +88,38 @@ class Compute(implicit config: AccelConfig) extends Module {
}
is (sReadAData) {
when (io.mem.rd.valid) {
state := sReadADone
}
}
is (sReadADone) {
when (cntwgt === (length * length) - 1.U) {
state := sReadBReq
} .otherwise {
state := sReadAReq
}
}
is (sReadBReq) {
state := sReadBData
}
is (sReadBData) {
when (io.mem.rd.valid) {
state := sReadBDone
}
}
is (sReadBDone) {
when (cntinp === length-1.U) {
state := sInpDone
} .otherwise {
state := sReadBReq
}
}
// Both inputs have been processed
is (sInpDone) {
state := sWait
}
// Wait for computation
is (sWait) {
when (accum.io.ready) {
state := sWriteReq
}
}
Expand All @@ -89,15 +128,18 @@ class Compute(implicit config: AccelConfig) extends Module {
state := sWriteData
}
is (sWriteData) {
when (cnt === (length - 1.U)) {
state := sWriteDone
}
is (sWriteDone) {
when (cntout === (length - 1.U)) {
state := sIdle
} .otherwise {
state := sReadAReq
state := sWriteReq
}
}
}

val last = state === sWriteData && cnt === (length - 1.U)
val last = state === sWriteDone && cntout === (length - 1.U)

// cycle counter
when (state === sIdle) {
Expand All @@ -114,10 +156,12 @@ class Compute(implicit config: AccelConfig) extends Module {
raddr1 := io.ptrs(0)
raddr2 := io.ptrs(1)
waddr := io.ptrs(2)
} .elsewhen (state === sWriteData) { // increment input array by 1-byte
} .elsewhen (state === sReadADone) { // increment input array by 1-byte
raddr1 := raddr1 + 1.U
} .elsewhen (state === sReadBDone) { // increment input array by 1-byte
raddr2 := raddr2 + 1.U
waddr := waddr
} .elsewhen (state === sWriteDone) {
waddr := waddr + 4.U // writing 4 bytes
}

// create request
Expand All @@ -128,59 +172,70 @@ class Compute(implicit config: AccelConfig) extends Module {

// read
when (state === sReadAData && io.mem.rd.valid) {
reg1 := io.mem.rd.bits(7, 0)
reg1(cntwgt/length)(cntwgt%length) := io.mem.rd.bits(7, 0)
}

when (state === sReadBData && io.mem.rd.valid) {
reg2 := io.mem.rd.bits(7, 0)
reg2(0)(cntinp) := io.mem.rd.bits(7, 0)
}

io.mem.rd.ready := state === sReadAData | state === sReadBData
mvc.io.inp.data.valid := state === sInpDone // 2 inputs have been processed
mvc.io.wgt.data.valid := state === sInpDone // 2 inputs have been processed

mvc.io.wgt.data.bits <> reg1
mvc.io.inp.data.bits <> reg2
// Modify when shift operation is supported
mvc.io.reset := false.B
mvc.io.acc_i.data.valid := true.B
for (i <- 0 until p(CoreKey).blockOut) {
mvc.io.acc_i.data.bits(0)(i) := 0.U
}


val sliceAccum = Module(new Accumulator(63))
val overallAccum = Module(new Accumulator(64))

sliceAccum.io.valid := state === sWriteReq // 2 inputs have been processed
sliceAccum.io.in := reg1 * reg2
sliceAccum.io.clear := startDot
overallAccum.io.clear := rstAccum
overallAccum.io.valid := last // last element has been processed
overallAccum.io.in := sliceAccum.io.sum << shift(7,0) // limit to 8 bits
accum.io.in := mvc.io.acc_o.data.bits
accum.io.shift := shift
accum.io.clear := rstAccum
accum.io.valid := mvc.io.acc_o.data.valid

// write
io.mem.wr.valid := overallAccum.io.ready
io.mem.wr.bits := overallAccum.io.sum

io.mem.wr.valid := state === sWriteData
io.mem.wr.bits := accum.io.sum(cntout)

// count read/write
when (state === sIdle) {
cnt := 0.U
} .elsewhen (state === sWriteData) {
cnt := cnt + 1.U
cntwgt := 0.U
cntinp := 0.U
cntout := 0.U
} .elsewhen (state === sReadADone) {
cntwgt := cntwgt + 1.U
} .elsewhen (state === sReadBDone) {
cntinp := cntinp + 1.U
} .elsewhen (state === sWriteDone) {
cntout := cntout + 1.U
}

io.finish := overallAccum.io.ready // data has been added
io.finish := last // data has been added
}


class Accumulator(dataBits: Int = 8) extends Module {
// Applies the shift operation here until it is supported in MVM
class Accmulator(size: Int = 16, accBits: Int = 32) extends Module {
val io = IO(new Bundle {
val clear = Input(Bool())
val valid = Input(Bool())
val ready = Output(Bool())
val in = Input(UInt(dataBits.W))
val sum = Output(UInt((dataBits).W))
val in = Input(Vec(1, Vec(size, (UInt(accBits.W)))))
val shift = Input(UInt(8.W))
val sum = Output(Vec(size, (UInt(accBits.W))))
})
val reg = RegInit(VecInit(Seq.fill(size)(0.U(accBits.W))))

val reg = RegInit(0.U((dataBits).W))
val ready = RegNext(io.valid)
when (io.clear) {
reg := 0.U
} .elsewhen (io.valid) {
reg := reg + io.in
}
io.ready := ready
io.sum := reg
for (i <- 0 until size) {
when (io.clear) {
reg(i) := 0.U
} .elsewhen(io.valid) {
reg(i) := reg(i) + (io.in(0)(i) << io.shift)
}
}
io.ready := RegNext(io.valid)
io.sum := reg
}

10 changes: 3 additions & 7 deletions vta/apps/gemm/hardware/chisel/src/main/scala/accel/RegFile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,9 @@ import vta.dpi._
* Shift value | 0x08
* Vector length | 0x0c
* Reset Accumulator | 0x10
* Reset Dot Module | 0x14
* Input1 pointer lsb | 0x18
* Input1 pointer msb | 0x1c
* Input2 pointer lsb | 0x20
* Input2 pointer msb | 0x24
* Output pointer lsb | 0x28
* Output pointer msb | 0x2c
* Input1 pointer | 0x18
* Input2 pointer | 0x20
* Output pointer | 0x28
* -------------------------------

* ------------------------------
Expand Down
18 changes: 9 additions & 9 deletions vta/apps/gemm/src/driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,12 @@ class Device {

uint32_t Run(DLTensor* inp1, DLTensor* inp2, uint32_t shiftVal, DLTensor* out, uint32_t reset) {
uint32_t cycles;
uint32_t length = inp1->shape[0];
size_t size1 = (inp1->dtype.bits >> 3) * length;
uint32_t length = inp2->shape[0];
// 1 matrix 1 vector input
size_t size1 = (inp1->dtype.bits >> 3) * length * length;
size_t size2 = (inp2->dtype.bits >> 3) * length;
size_t size3 = (64 >> 3);
// 1 vector output
size_t size3 = (32 >> 3) * length;
inp1_ = this->MemAlloc(size1);
inp2_ = this->MemAlloc(size2);
out_ = this->MemAlloc(size3);
Expand Down Expand Up @@ -115,19 +117,17 @@ class Device {

void Launch(uint32_t length, uint32_t shiftVal, uint32_t reset) {
dpi_->WriteReg(0x08, shiftVal);
dpi_->WriteReg(0x0c, length); // vector length
dpi_->WriteReg(0x0c, length); // tensor size
dpi_->WriteReg(0x18, this->MemGetPhyAddr(inp1_));
dpi_->WriteReg(0x20, this->MemGetPhyAddr(inp2_));
dpi_->WriteReg(0x28, this->MemGetPhyAddr(out_));
dpi_->WriteReg(0x00, 0x1); // launch
dpi_->WriteReg(0x00, 0x0); // launch
dpi_->WriteReg(0x00, 0x0);

if (reset == 1) {
dpi_->WriteReg(0x10, 0x1); // reset accum
dpi_->WriteReg(0x10, 0x0); // stop reset accum
dpi_->WriteReg(0x10, 0x1); // reset accumulator
dpi_->WriteReg(0x10, 0x0);
}
dpi_->WriteReg(0x14, 0x1); // reset dot
dpi_->WriteReg(0x14, 0x0); // stop reset dot
}

uint32_t WaitForCompletion() {
Expand Down
Loading