Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 14 additions & 15 deletions vta/hardware/chisel/src/main/scala/core/TensorLoad.scala
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,21 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
when(dec.xpad_1 =/= 0.U) {
state := sXPad1
}.elsewhen(dec.ypad_1 =/= 0.U) {
state := sYPad1
}
.otherwise {
state := sIdle
}
}.elsewhen(dataCtrl.io.stride || dataCtrl.io.split) {
state := sYPad1
}
.otherwise {
state := sIdle
}
}.elsewhen(dataCtrl.io.stride) {
when(dec.xpad_1 =/= 0.U) {
state := sXPad1
}.elsewhen(dec.xpad_0 =/= 0.U) {
state := sXPad0
}
.otherwise {
state := sReadCmd
}
state := sXPad0
}.otherwise {
state := sReadCmd
}
}.elsewhen(dataCtrl.io.split) {
state := sReadCmd
}
}
}
Expand Down Expand Up @@ -168,13 +169,11 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
xPadCtrl0.io.start := dec.xpad_0 =/= 0.U &
((state === sIdle & io.start) |
(state === sYPad0 & yPadCtrl0.io.done) |
(io.vme_rd.data
.fire() & ~dataCtrlDone & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 === 0.U) |
(io.vme_rd.data.fire() & ~dataCtrlDone & dataCtrl.io.stride & dec.xpad_1 === 0.U) |
(state === sXPad1 & xPadCtrl1.io.done & ~dataCtrlDone))

xPadCtrl1.io.start := dec.xpad_1 =/= 0.U & io.vme_rd.data.fire() &
((dataCtrl.io.done) |
(~dataCtrl.io.done & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 =/= 0.U))
((dataCtrl.io.done) | (~dataCtrl.io.done & dataCtrl.io.stride & dec.xpad_1 =/= 0.U))

yPadCtrl0.io.inst := io.inst
yPadCtrl1.io.inst := io.inst
Expand Down
116 changes: 61 additions & 55 deletions vta/tests/python/unittest/test_vta_insn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import vta.testing
from vta.testing import simulator

np.random.seed(0xdeadb)

def test_save_load_out():
"""Test save/store output command"""
Expand Down Expand Up @@ -88,68 +89,73 @@ def _run(env, remote):
def test_padded_load():
"""Test padded load."""
def _run(env, remote):
# declare
n = 3
m = 5
pad_before = [2, 1, 0, 0]
pad_after = [1, 2, 0, 0]
x = tvm.placeholder(
(n, m, env.BATCH, env.BLOCK_OUT),
name="x",
dtype=env.acc_dtype)
x_buf = topi.nn.pad(x, pad_before, pad_after, name="y")
# insert no-op that won't be optimized away
y_buf = tvm.compute((n + pad_before[0] + pad_after[0],
def check_padded_load(pad_before, pad_after, test_name=None):
# declare
n = 3
m = 5
x = tvm.placeholder(
(n, m, env.BATCH, env.BLOCK_OUT),
name="x",
dtype=env.acc_dtype)
x_buf = topi.nn.pad(x, pad_before, pad_after, name="y")
# insert no-op that won't be optimized away
y_buf = tvm.compute((n + pad_before[0] + pad_after[0],
m + pad_before[1] + pad_after[1],
env.BATCH,
env.BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf")
y = tvm.compute((n + pad_before[0] + pad_after[0],
m + pad_before[1] + pad_after[1],
env.BATCH,
env.BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf")
y = tvm.compute((n + pad_before[0] + pad_after[0],
m + pad_before[1] + pad_after[1],
env.BATCH,
env.BLOCK_OUT), lambda *i: y_buf(*i).astype(env.inp_dtype), "y")
# schedule
s = tvm.create_schedule(y.op)
s[x_buf].set_scope(env.acc_scope)
s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
s[y_buf].set_scope(env.acc_scope)
s[y_buf].pragma(y_buf.op.axis[0], env.alu)
s[y].pragma(y.op.axis[0], env.dma_copy)
# build
with vta.build_config():
mod = vta.build(s, [x, y], "ext_dev", env.target_host)
env.BLOCK_OUT), lambda *i: y_buf(*i).astype(env.inp_dtype), "y")
# schedule
s = tvm.create_schedule(y.op)
s[x_buf].set_scope(env.acc_scope)
s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
s[y_buf].set_scope(env.acc_scope)
s[y_buf].pragma(y_buf.op.axis[0], env.alu)
s[y].pragma(y.op.axis[0], env.dma_copy)
# build
with vta.build_config():
mod = vta.build(s, [x, y], "ext_dev", env.target_host)

if not remote:
return
temp = util.tempdir()
mod.save(temp.relpath("padded_load.o"))
remote.upload(temp.relpath("padded_load.o"))
f = remote.load_module("padded_load.o")
# verify
ctx = remote.ext_dev(0)
x_np = np.random.randint(-10, 10, size=(
n, m, env.BATCH, env.BLOCK_OUT)).astype(x.dtype)
y_np = np.zeros((n + pad_before[0] + pad_after[0],
m + pad_before[1] + pad_after[1],
env.BATCH,
env.BLOCK_OUT)).astype(y.dtype)
y_np[pad_before[0]:pad_before[0] + n,
pad_before[1]:pad_before[1] + m,
:] = x_np
x_nd = tvm.nd.array(x_np, ctx)
y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
if not remote:
return
temp = util.tempdir()
mod.save(temp.relpath("padded_load.o"))
remote.upload(temp.relpath("padded_load.o"))
f = remote.load_module("padded_load.o")
# verify
ctx = remote.ext_dev(0)
x_np = np.random.randint(0, 10, size=(
n, m, env.BATCH, env.BLOCK_OUT)).astype(x.dtype)
y_np = np.zeros((n + pad_before[0] + pad_after[0],
m + pad_before[1] + pad_after[1],
env.BATCH,
env.BLOCK_OUT)).astype(y.dtype)
y_np[pad_before[0]:pad_before[0] + n,
pad_before[1]:pad_before[1] + m,
:] = x_np
x_nd = tvm.nd.array(x_np, ctx)
y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)

if env.TARGET in ["sim", "tsim"]:
simulator.clear_stats()
if env.TARGET in ["sim", "tsim"]:
simulator.clear_stats()

f(x_nd, y_nd)
f(x_nd, y_nd)

np.testing.assert_equal(y_np, y_nd.asnumpy())
np.testing.assert_equal(y_np, y_nd.asnumpy())

if env.TARGET in ["sim", "tsim"]:
sim_stats = simulator.stats()
print("Padded load execution statistics:")
for k, v in sim_stats.items():
print("\t{:<16}: {:>16}".format(k, v))
if env.TARGET in ["sim", "tsim"]:
sim_stats = simulator.stats()
print("Padded {} load execution statistics:".format(test_name))
for k, v in sim_stats.items():
print("\t{:<16}: {:>16}".format(k, v))

check_padded_load([2, 0, 0, 0], [0, 0, 0, 0], test_name="Y0")
check_padded_load([0, 2, 0, 0], [0, 0, 0, 0], test_name="Y1")
check_padded_load([0, 0, 0, 0], [2, 0, 0, 0], test_name="X0")
check_padded_load([0, 0, 0, 0], [0, 2, 0, 0], test_name="X1")
check_padded_load([1, 1, 0, 0], [1, 1, 0, 0], test_name="all")

vta.testing.run(_run)

Expand Down