
[Bug] [Frontend][Tensorflow] tf.where with broadcast condition fails to import due to Incompatible broadcast type #13855

@balaram-cadence

Description


The test case below fails to import into TVM:

```python
# Relies on the imports and helpers already present in
# tests/python/frontend/tensorflow/test_forward.py:
#   - np is NumPy
#   - tf is the TF1-compatible API (tf.placeholder is not in TF2's default namespace)
#   - math_ops comes from tensorflow.python.ops
#   - compare_tf_with_tvm is defined in that file


def test_forward_where_with_broadcast_cond():
    t1 = np.array([1.0, 2.0, 3.0, 4.0, 5.0]).astype("float32")
    t2 = np.array([2.0, 4.0, 1.0, 3.0, 5.0]).astype("float32")
    x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]).astype("float32")
    y = np.array([[10.0, 9.0], [8.0, 7.0], [6.0, 5.0], [4.0, 3.0], [2.0, 1.0]]).astype("float32")

    with tf.Graph().as_default():
        in1 = tf.placeholder(shape=(5,), dtype="float32", name="in1")
        in2 = tf.placeholder(shape=(5,), dtype="float32", name="in2")
        # Rank-1 condition of shape (5,), one entry per row of x and y.
        condition = math_ops.less(in1, in2, name="less")
        lhs = tf.placeholder(shape=(5, 2), dtype="float32", name="x")
        rhs = tf.placeholder(shape=(5, 2), dtype="float32", name="y")
        out = tf.where(condition, lhs, rhs)
        compare_tf_with_tvm([t1, t2, x, y], ["in1:0", "in2:0", "x:0", "y:0"], out.name)
```

Expected behavior

The TVM output should match the TensorFlow output:

```
[array([[1., 2.],
       [3., 4.],
       [6., 5.],
       [4., 3.],
       [2., 1.]], dtype=float32)]
```
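For reference, the expected output follows TF1 `Select` semantics, as documented for `tf.compat.v1.where`: when `condition` is a vector and `x`/`y` have higher rank, the size of `condition` must match the first dimension of `x`, and `condition[i]` selects the whole row `x[i]` or `y[i]`. The NumPy sketch below reproduces that rule on the test inputs. It assumes that the test's `tf.where` lowers to the TF1 `Select` op, and `tf1_select_reference` is a hypothetical helper written for illustration, not a TensorFlow or TVM API.

```python
import numpy as np


def tf1_select_reference(condition, x, y):
    """Hypothetical NumPy reference for TF1 `Select` (tf.compat.v1.where with x and y).

    - If condition has the same shape as x, selection is elementwise.
    - If condition is a vector and x has higher rank, condition[i] picks the
      entire slice x[i] or y[i] (row selection along the first axis).
    """
    condition = np.asarray(condition, dtype=bool)
    x = np.asarray(x)
    y = np.asarray(y)
    if x.shape != y.shape:
        raise ValueError("x and y must have the same shape")
    if condition.shape == x.shape:
        return np.where(condition, x, y)
    if condition.ndim == 1 and x.ndim > 1 and condition.shape[0] == x.shape[0]:
        # Append singleton axes, e.g. (5,) -> (5, 1), so each condition[i]
        # broadcasts across the trailing dimensions of row i.
        cond = condition.reshape(condition.shape + (1,) * (x.ndim - 1))
        return np.where(cond, x, y)
    raise ValueError(f"unsupported shapes: condition {condition.shape}, x {x.shape}")


# Same inputs as the test case above.
t1 = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
t2 = np.array([2.0, 4.0, 1.0, 3.0, 5.0], dtype=np.float32)
x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]], dtype=np.float32)
y = np.array([[10.0, 9.0], [8.0, 7.0], [6.0, 5.0], [4.0, 3.0], [2.0, 1.0]], dtype=np.float32)

condition = t1 < t2  # [True, True, False, False, False]
result = tf1_select_reference(condition, x, y)
print(result)
```

Rows 0 and 1 come from `x` and rows 2 to 4 come from `y`, which matches the TensorFlow output shown above.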

Actual behavior

The import fails with the following error:

```
Incompatible broadcast type TensorType([5], bool) and TensorType([5, 2], float32)
```
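A minimal NumPy sketch of why these shapes clash. It assumes the TVM converter hands the TF condition to a `where` that applies NumPy-style broadcasting. The `Incompatible broadcast type` message suggests this, but the assumption is inferred from the error, not taken from the TVM source. The sketch also shows that expanding the condition to shape `(5, 1)` makes the operands compatible and reproduces TF1's row selection. This is one possible approach for the converter, not a confirmed fix.

```python
import numpy as np

# Shapes from the error message: condition (5,) bool, x/y (5, 2) float32.
condition = np.array([True, True, False, False, False])
x = np.arange(10, dtype=np.float32).reshape(5, 2)
y = x + 100.0

# NumPy-style broadcasting aligns trailing axes: 5 vs 2 neither match nor
# equal 1, so a broadcasting `where` rejects these operands.
try:
    np.where(condition, x, y)
    broadcast_failed = False
except ValueError:
    broadcast_failed = True

# Appending a trailing axis gives (5, 1), which broadcasts against (5, 2),
# so each condition[i] selects the whole row x[i] or y[i].
expanded = np.expand_dims(condition, axis=-1)
out = np.where(expanded, x, y)

print(broadcast_failed)  # True
print(out.shape)         # (5, 2)
```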

Environment

```
Linux
LSB Version:    :core-4.1-amd64:core-4.1-ia32:core-4.1-noarch:cxx-4.1-amd64:cxx-4.1-ia32:cxx-4.1-noarch:desktop-4.1-amd64:desktop-4.1-ia32:desktop-4.1-noarch:languages-4.1-amd64:languages-4.1-noarch:printing-4.1-amd64:printing-4.1-noarch
Distributor ID: RedHatEnterpriseWorkstation
Description:    Red Hat Enterprise Linux Workstation release 7.9 (Maipo)
Release:        7.9
Codename:       Maipo

Name: apache-tvm
Version: 0.10.0
Home-page: https://tlcpack.ai
Author: Apache TVM
Author-email: None
License: Apache

Name: tensorflow
Version: 2.5.0
Summary: TensorFlow is an open source machine learning framework for everyone.
Home-page: https://www.tensorflow.org/
Author: Google Inc.
Author-email: packages@tensorflow.org
License: Apache 2.0
```

Steps to reproduce

Add the test case above to tests/python/frontend/tensorflow/test_forward.py and run:

```shell
python -m pytest tests/python/frontend/tensorflow/test_forward.py -k test_forward_where_with_broadcast_cond
```

Triage

  • needs-triage
  • frontend:tensorflow

Labels: needs-triage, type: bug
