diff --git a/src/topi/einsum.cc b/src/topi/einsum.cc
index 892a17e58d7f..3e9cd358e929 100644
--- a/src/topi/einsum.cc
+++ b/src/topi/einsum.cc
@@ -98,24 +98,33 @@ EinsumEquation EinsumEquation::FromString(const std::string& equation) {
 }

 PrimExpr GetBroadcastedExtent(const PrimExpr& extent1, const PrimExpr& extent2) {
-  int64_t extent1_value = GetConstInt(extent1);
-  int64_t extent2_value = GetConstInt(extent2);
-  if (extent1_value == extent2_value) {
+  const IntImmNode* extent1_imm = extent1.as<IntImmNode>();
+  const IntImmNode* extent2_imm = extent2.as<IntImmNode>();
+  if (extent1_imm != nullptr && extent2_imm != nullptr) {
+    if (extent1_imm->value == extent2_imm->value) {
+      return extent1;
+    } else if (extent1_imm->value == 1 || extent2_imm->value == 1) {
+      return Integer(std::max(extent1_imm->value, extent2_imm->value));
+    }
+    LOG(FATAL) << "Cannot broadcast extents " << extent1 << " and " << extent2;
+    throw;
+  } else if (extent1_imm != nullptr) {
+    return extent2;
+  } else if (extent2_imm != nullptr) {
     return extent1;
-  } else if (extent1_value == 1 || extent2_value == 1) {
-    return Integer(std::max(extent1_value, extent2_value));
+  } else {
+    return max(extent1, extent2);
   }
-  LOG(FATAL) << "Cannot broadcast extents " << extent1 << " and " << extent2;
-  throw;
 }

 PrimExpr GetIndexForBroadcastedDim(const Var& index, const PrimExpr& extent,
                                    const PrimExpr& broadcasted_extent) {
-  if (GetConstInt(extent) == GetConstInt(broadcasted_extent)) {
-    return index;
-  } else {
-    return Integer(0);
+  // The dimension is broadcast to `broadcasted_extent` only when its extent is the constant 1;
+  // symbolic extents are never treated as broadcast, so symbolic shapes are handled as well.
+  if (is_one(extent) && !is_one(broadcasted_extent)) {
+    return make_zero(index.dtype());
   }
+  return index;
 }

 /*! \brief The compute builder for Einsum */
diff --git a/tests/python/topi/python/test_topi_einsum.py b/tests/python/topi/python/test_topi_einsum.py
index d6dc43e4da00..a84cbaffc185 100644
--- a/tests/python/topi/python/test_topi_einsum.py
+++ b/tests/python/topi/python/test_topi_einsum.py
@@ -23,39 +23,59 @@
 from tvm.topi.utils import get_const_tuple


-def with_tvm(lam, *args):
+def with_tvm(lam, shapes, ops, out_shape):
     """Take numpy arrays as args, convert them to TVM tensors and call `lam`.
     Result of lambda is converted back to numpy array and returned.
     """
     dev = tvm.cpu(0)
     pls = []  # placeholders
     vals_nd = []  # initial values
-    for i, arg in enumerate(args):
-        pls.append(te.placeholder(arg.shape, name="pl" + str(i)))
+    for i, (shape, arg) in enumerate(zip(shapes, ops)):
+        pls.append(te.placeholder(shape, name="pl" + str(i)))
         vals_nd.append(tvm.nd.array(arg, dev))
     out = lam(*pls)
-    out_nd = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out.dtype), dev)
+    out_nd = tvm.nd.array(np.zeros(out_shape).astype(out.dtype), device=dev)
     s = te.create_schedule([out.op])
     m = tvm.build(s, pls + [out], "llvm")
     m(*(vals_nd + [out_nd]))
     return out_nd.numpy()


-def verify_einsum(subscripts, shapes):
-    ops = []
+def verify_einsum(subscripts, shapes, shape_dict={}):
+    ops = []  # ndarrays to be used as inputs
+    symbolic_shapes = []  # shapes to declare the placeholders
+    name_to_var = {}
+
+    def get_concrete_shape(shape):
+        return [shape_dict[s] if isinstance(s, str) else s for s in shape]
+
+    def get_symbolic_shape_var(name, dtype="int32"):
+        if name not in name_to_var:
+            name_to_var[name] = te.var(name, dtype=dtype)
+        return name_to_var[name]
+
+    def get_symbolic_shape(shape):
+        return [get_symbolic_shape_var(s) if isinstance(s, str) else s for s in shape]
+
     for shape in shapes:
-        tmp = np.random.uniform(low=-1.0, high=1.0, size=shape).astype(np.float32)
+        concrete_shape = get_concrete_shape(shape)
+        tmp = np.random.uniform(low=-1.0, high=1.0, size=concrete_shape).astype(np.float32)
         ops.append(tmp)
+        symbolic_shape = get_symbolic_shape(shape)
+        symbolic_shapes.append(symbolic_shape)

     c1 = np.einsum(subscripts, *ops)
+    out_shape = c1.shape

     if len(ops) == 1:
-        c2 = with_tvm(lambda A: topi.einsum(subscripts, A), *ops)
+        c2 = with_tvm(lambda A: topi.einsum(subscripts, A), symbolic_shapes, ops, out_shape)
     elif len(ops) == 2:
-        c2 = with_tvm(lambda A, B: topi.einsum(subscripts, A, B), *ops)
+        c2 = with_tvm(lambda A, B: topi.einsum(subscripts, A, B), symbolic_shapes, ops, out_shape)
     elif len(ops) == 3:
-        c2 = with_tvm(lambda A, B, C: topi.einsum(subscripts, A, B, C), *ops)
+        c2 = with_tvm(
+            lambda A, B, C: topi.einsum(subscripts, A, B, C), symbolic_shapes, ops, out_shape
+        )

     tvm.testing.assert_allclose(c1, c2, rtol=1e-5, atol=1e-5)


@@ -82,5 +102,17 @@ def test_einsum(equation, inputs):
     verify_einsum(equation, inputs)


+@pytest.mark.parametrize(
+    "equation,inputs,shape_dict",
+    [
+        ("ij,jk->ik", [(2, "K"), (1, "N")], {"K": 3, "N": 4}),
+        ("ij,jk->ik", [(2, "K"), ("K2", "N")], {"K": 3, "N": 4, "K2": 3}),
+        ("ij,jk->ik", [(2, "K"), ("K2", "N")], {"K": 3, "N": 4, "K2": 1}),
+    ],
+)
+def test_einsum_symbolic_shape(equation, inputs, shape_dict):
+    verify_einsum(equation, inputs, shape_dict)
+
+
 if __name__ == "__main__":
     tvm.testing.main()