src/topi/einsum.cc (20 additions & 11 deletions)

```diff
@@ -98,24 +98,33 @@ EinsumEquation EinsumEquation::FromString(const std::string& equation) {
 }
 
 PrimExpr GetBroadcastedExtent(const PrimExpr& extent1, const PrimExpr& extent2) {
-  int64_t extent1_value = GetConstInt(extent1);
-  int64_t extent2_value = GetConstInt(extent2);
-  if (extent1_value == extent2_value) {
-    return extent1;
-  } else if (extent1_value == 1 || extent2_value == 1) {
-    return Integer(std::max(extent1_value, extent2_value));
+  const IntImmNode* extent1_imm = extent1.as<IntImmNode>();
+  const IntImmNode* extent2_imm = extent2.as<IntImmNode>();
+  if (extent1_imm != nullptr && extent2_imm != nullptr) {
+    if (extent1_imm->value == extent2_imm->value) {
+      return extent1;
+    } else if (extent1_imm->value == 1 || extent2_imm->value == 1) {
+      return Integer(std::max(extent1_imm->value, extent2_imm->value));
+    }
+    LOG(FATAL) << "Cannot broadcast extents " << extent1 << " and " << extent2;
+    throw;
+  } else if (extent1_imm != nullptr) {
+    return extent2;
+  } else if (extent2_imm != nullptr) {
+    return extent1;
+  } else {
+    return max(extent1, extent2);
   }
-  LOG(FATAL) << "Cannot broadcast extents " << extent1 << " and " << extent2;
-  throw;
 }
 
 PrimExpr GetIndexForBroadcastedDim(const Var& index, const PrimExpr& extent,
                                    const PrimExpr& broadcasted_extent) {
-  if (GetConstInt(extent) == GetConstInt(broadcasted_extent)) {
-    return index;
-  } else {
-    return Integer(0);
+  // Check whether the current dimension is being broadcast to
+  // `broadcasted_extent` (symbolic shapes are handled).
+  if (is_one(extent) && !is_one(broadcasted_extent)) {
+    return make_zero(index.dtype());
   }
+  return index;
 }
 
 /*! \brief The compute builder for Einsum */
```

A reviewer (Member) left one inline note on the new `GetBroadcastedExtent`: one minor note for the future, it would be useful to have common broadcast-extent handling.
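To summarize the new control flow: two static extents must match exactly or one must be 1; a static extent paired with a symbolic one defers to the symbolic side; two symbolic extents fall back to a runtime `max`. Below is a minimal Python sketch of these rules (hypothetical stand-alone helpers, not TVM APIs; ints model static extents and strings model symbolic ones):

```python
def broadcast_extent(e1, e2):
    """Mirror of GetBroadcastedExtent: combine two extents under broadcasting."""
    static1, static2 = isinstance(e1, int), isinstance(e2, int)
    if static1 and static2:
        if e1 == e2:
            return e1
        if e1 == 1 or e2 == 1:
            return max(e1, e2)
        raise ValueError(f"Cannot broadcast extents {e1} and {e2}")
    if static1:  # static vs. symbolic: keep the symbolic extent
        return e2
    if static2:
        return e1
    return f"max({e1}, {e2})"  # both symbolic: resolved at runtime


def index_for_broadcasted_dim(index, extent, broadcasted_extent):
    """Mirror of GetIndexForBroadcastedDim: map an output index to an input index."""
    # Only a dimension statically known to be 1 that is broadcast to a larger
    # extent collapses every output index to input position 0.
    if extent == 1 and broadcasted_extent != 1:
        return 0
    return index


assert broadcast_extent(3, 1) == 3
assert broadcast_extent(2, "K") == "K"
assert index_for_broadcasted_dim(5, 1, "N") == 0
```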
tests/python/topi/python/test_topi_einsum.py (42 additions & 10 deletions)

```diff
@@ -23,39 +23,59 @@
 from tvm.topi.utils import get_const_tuple
 
 
-def with_tvm(lam, *args):
+def with_tvm(lam, shapes, ops, out_shape):
     """Take numpy arrays as args, convert them to TVM tensors and call `lam`.
     Result of lambda is converted back to numpy array and returned.
     """
     dev = tvm.cpu(0)
     pls = []  # placeholders
     vals_nd = []  # initial values
-    for i, arg in enumerate(args):
-        pls.append(te.placeholder(arg.shape, name="pl" + str(i)))
+    for i, (shape, arg) in enumerate(zip(shapes, ops)):
+        pls.append(te.placeholder(shape, name="pl" + str(i)))
         vals_nd.append(tvm.nd.array(arg, dev))
 
     out = lam(*pls)
-    out_nd = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out.dtype), dev)
+    out_nd = tvm.nd.array(np.zeros(out_shape).astype(out.dtype), device=dev)
     s = te.create_schedule([out.op])
     m = tvm.build(s, pls + [out], "llvm")
     m(*(vals_nd + [out_nd]))
     return out_nd.numpy()
 
 
-def verify_einsum(subscripts, shapes):
-    ops = []
+def verify_einsum(subscripts, shapes, shape_dict={}):
+    ops = []  # ndarrays to be used as inputs
+    symbolic_shapes = []  # shapes to declare the placeholders
+    name_to_var = {}
+
+    def get_concrete_shape(shape):
+        return [shape_dict[s] if isinstance(s, str) else s for s in shape]
+
+    def get_symbolic_shape_var(name, dtype="int32"):
+        if name not in name_to_var:
+            name_to_var[name] = te.var(name, dtype=dtype)
+        return name_to_var[name]
+
+    def get_symbolic_shape(shape):
+        return [get_symbolic_shape_var(s) if isinstance(s, str) else s for s in shape]
+
     for shape in shapes:
-        tmp = np.random.uniform(low=-1.0, high=1.0, size=shape).astype(np.float32)
+        concrete_shape = get_concrete_shape(shape)
+        tmp = np.random.uniform(low=-1.0, high=1.0, size=concrete_shape).astype(np.float32)
         ops.append(tmp)
+        symbolic_shape = get_symbolic_shape(shape)
+        symbolic_shapes.append(symbolic_shape)
 
     c1 = np.einsum(subscripts, *ops)
+    out_shape = c1.shape
 
     if len(ops) == 1:
-        c2 = with_tvm(lambda A: topi.einsum(subscripts, A), *ops)
+        c2 = with_tvm(lambda A: topi.einsum(subscripts, A), symbolic_shapes, ops, out_shape)
     elif len(ops) == 2:
-        c2 = with_tvm(lambda A, B: topi.einsum(subscripts, A, B), *ops)
+        c2 = with_tvm(lambda A, B: topi.einsum(subscripts, A, B), symbolic_shapes, ops, out_shape)
     elif len(ops) == 3:
-        c2 = with_tvm(lambda A, B, C: topi.einsum(subscripts, A, B, C), *ops)
+        c2 = with_tvm(
+            lambda A, B, C: topi.einsum(subscripts, A, B, C), symbolic_shapes, ops, out_shape
+        )
 
     tvm.testing.assert_allclose(c1, c2, rtol=1e-5, atol=1e-5)
```
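The reason `with_tvm` now takes explicit `shapes` and `out_shape`: once placeholders are declared with symbolic extents, `out.shape` can contain non-constant expressions, so the old `np.zeros(get_const_tuple(out.shape), ...)` allocation no longer works; the concrete output shape computed by `np.einsum` is passed in instead. A small sketch of the situation, using the same `te`/`topi` calls as the tests (the shapes are illustrative):

```python
from tvm import te, topi

# A symbolic input extent can make the einsum output shape symbolic too.
K = te.var("K", dtype="int32")
N = te.var("N", dtype="int32")
A = te.placeholder((2, K), name="A")
B = te.placeholder((K, N), name="B")
out = topi.einsum("ij,jk->ik", A, B)
print(out.shape)  # (2, N): N is a te.var, so np.zeros cannot allocate from it
```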

```diff
@@ -82,5 +102,17 @@ def test_einsum(equation, inputs):
     verify_einsum(equation, inputs)
 
 
+@pytest.mark.parametrize(
+    "equation,inputs,shape_dict",
+    [
+        ("ij,jk->ik", [(2, "K"), (1, "N")], {"K": 3, "N": 4}),
+        ("ij,jk->ik", [(2, "K"), ("K2", "N")], {"K": 3, "N": 4, "K2": 3}),
+        ("ij,jk->ik", [(2, "K"), ("K2", "N")], {"K": 3, "N": 4, "K2": 1}),
+    ],
+)
+def test_einsum_symbolic_shape(equation, inputs, shape_dict):
+    verify_einsum(equation, inputs, shape_dict)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
```
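The second and third parametrized cases pair two distinct symbolic variables (`K` and `K2`) on the contracted axis, so the broadcast falls through to the symbolic handling in `GetBroadcastedExtent` above. For reference, here is an end-to-end sketch of the workflow these tests exercise: declare a symbolic extent with `te.var`, build an einsum once, then run it with concrete inputs (the concrete shapes and the plain schedule are illustrative, not part of the PR):

```python
import numpy as np
import tvm
import tvm.testing
from tvm import te, topi

# Declare the shared reduction extent symbolically.
K = te.var("K", dtype="int32")
A = te.placeholder((2, K), name="A")
B = te.placeholder((K, 4), name="B")
C = topi.einsum("ij,jk->ik", A, B)

s = te.create_schedule([C.op])
f = tvm.build(s, [A, B, C], "llvm")

dev = tvm.cpu(0)
a = tvm.nd.array(np.random.rand(2, 3).astype("float32"), dev)  # K resolves to 3
b = tvm.nd.array(np.random.rand(3, 4).astype("float32"), dev)
c = tvm.nd.array(np.zeros((2, 4), dtype="float32"), dev)
f(a, b, c)
tvm.testing.assert_allclose(c.numpy(), np.einsum("ij,jk->ik", a.numpy(), b.numpy()), rtol=1e-5)
```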