diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index e44653ff1ba9..ceb9c6abd5de 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -97,6 +97,23 @@ def _need_module_for_shape_inference(op): def _need_prelude_for_shape_inference(op): return "TensorArray" in op +def _extract_tensor_data(prelude, rank, dtype, source): + get_tensor_name = 'get_tensor{}'.format(rank) + get_tensor_func = prelude.get_var(get_tensor_name, dtype) + return get_tensor_func(source) + +def _get_tensor_shape(tensor, mod): + if isinstance(tensor, tvm.relay.expr.Var): + return tensor.type_annotation.shape + # TODO: Run infer_type to get the tensor rank + raise Exception("TODO") + +def _get_tensor_rank(tensor, mod): + if isinstance(tensor, tvm.relay.expr.Var): + return len(tensor.type_annotation.shape) + # TODO: Run infer_type to get the tensor rank + raise Exception("TODO") + def _rsqrt(): def _impl(inputs, attr, params): inputs.append(tvm.relay.const(-0.5, attr['T'].name)) @@ -538,17 +555,31 @@ def _impl(inputs, attr, params, prelude): def _tensor_array_scatter(): def _impl(inputs, attr, params, prelude): dtype_str = attr.get('T').name - values_rank = len(inputs[2].type_annotation.shape) + values_shape = _get_tensor_shape(inputs[2], prelude.mod) + values_rank = len(values_shape) unstack_name = "tensor_array_unstack_tensor{}".format(values_rank) unstack_function = prelude.get_var(unstack_name, dtype_str) values = unstack_function(inputs[2]) + tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str) - return tensor_array_scatter_func(inputs[0], inputs[1], values) + result = tensor_array_scatter_func(inputs[0], inputs[1], values) + _tensor_array_shape_tracker.trace(result, tuple(values_shape[1:])) + return result return _impl def _tensor_array_gather(): def _impl(inputs, attr, params, prelude): - return prelude.tensor_array_gather(inputs[2], inputs[1]) + dtype_str = 
attr.get('dtype').name
+        tensor_array_gather_func = prelude.get_var('tensor_array_gather', dtype_str)
+        result = tensor_array_gather_func(inputs[2], inputs[1])
+
+        shape = _tensor_array_shape_tracker.get_shape(inputs[2])
+        if shape is not None:
+            return _extract_tensor_data(prelude,
+                                        len(shape) + 1,
+                                        dtype_str,
+                                        result)
+        return result
     return _impl
 
 def _tensor_array_size():
@@ -558,7 +589,8 @@ def _impl(inputs, attr, params, prelude):
 
 def _tensor_array_write():
     def _impl(inputs, attr, params, prelude):
-        input_rank = len(inputs[2].type_annotation.shape)
+        input_shape = inputs[2].type_annotation.shape
+        input_rank = len(input_shape)
         dtype = attr.get('T').name
 
         tensor_name = 'tensor{}'.format(input_rank)
@@ -566,23 +598,43 @@ def _impl(inputs, attr, params, prelude):
         v = tensor_func(inputs[2])
 
         write_func = prelude.get_var('tensor_array_write', dtype)
-        return write_func(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v)
+        res = write_func(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v)
+        _tensor_array_shape_tracker.trace(inputs[3], input_shape)
+        _tensor_array_shape_tracker.union(inputs[3], res)
+        return res
    return _impl
 
 def _tensor_array_read():
     def _impl(inputs, attr, params, prelude):
         read_func = prelude.get_var('tensor_array_read', attr.get('dtype').name)
+        shape = _tensor_array_shape_tracker.get_shape(inputs[2])
+        if shape is not None:
+            rank = len(shape)
+            # Optimize for the case that all the tensors in the tensor array are of the same rank,
+            # we can emit code to extract the tensor out of the ADT object at compile time.
+ source = read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) + return _extract_tensor_data(prelude, + rank, + attr['dtype'].name, + source) + # The tensor array contains various ranks of tensor, it should be the rare case return read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) return _impl def _tensor_array_split(): def _impl(inputs, attr, params, prelude): - input_rank = len(inputs[1].type_annotation.shape) + input_shape = _get_tensor_shape(inputs[1], prelude.mod) + input_rank = len(input_shape) dtype_str = attr.get('T').name v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1]) lengths = _op.cast(inputs[2], 'int32') - split_var = prelude.get_var('tensor_array_split', dtype_str) - return split_var(inputs[0], v, lengths) + split_func = prelude.get_var('tensor_array_split', dtype_str) + result = split_func(inputs[0], v, lengths) + + # TODO: How can we calculate the result shape here + result_shape = [None] + input_shape[1:] + _tensor_array_shape_tracker.trace(result, tuple(result_shape)) + return result return _impl def _tensor_array_concat(): @@ -1371,6 +1422,43 @@ def _impl(inputs, attr, params): return _res return _impl +class TensorArrayShapeTracker(): + def __init__(self): + self.expr_shapes_map_ = defaultdict(set) + self.expr_union_find = {} + + def find(self, expr): + if expr not in self.expr_union_find: + self.expr_union_find[expr] = expr + while self.expr_union_find[expr] != expr: + expr = self.expr_union_find[expr] + return expr + + def union(self, a, b): + group_a = self.find(a) + group_b = self.find(b) + if group_a != group_b: + self.expr_union_find[group_a] = group_b + self.expr_shapes_map_[group_b] = self.expr_shapes_map_[group_b].union(self.expr_shapes_map_[group_a]) + + def trace(self, expr, shape): + self.expr_shapes_map_[expr].add(shape) + + def get_shape(self, expr): + shapes = self.expr_shapes_map_[self.find(expr)] + # TODO: Fix Hack, use string to dedup shapes + shapes_str = set(str(x) for x in shapes) 
+        if shapes:
+            if len(shapes_str) == 1:
+                return list(set(shapes))[0]
+            else:
+                raise Exception(
+                    "Tensor array contains tensors of different shapes: {}".format(shapes))
+        return None
+
+# Remember the rank of tensors stored in each tensor array
+_tensor_array_ranks = defaultdict(set)
+_tensor_array_shape_tracker = TensorArrayShapeTracker()
 
 # compatible operators that do NOT require any conversion.
 _identity_list = []
diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py
index ddb9302b0810..763b02cf4121 100644
--- a/python/tvm/relay/prelude.py
+++ b/python/tvm/relay/prelude.py
@@ -138,6 +138,134 @@ def define_tensor_take(self):
                               tensor6_case], False),
                      tensor_t(), [])
 
+    def define_get_tensor0(self):
+        """Defines a function to extract a rank 0 tensor out of tensor_t.
+        get_tensor0(t): tensor_t -> Tensor[(), dtype]
+        """
+        get_tensor0_name = self.get_name("get_tensor0")
+        get_tensor0_var = GlobalVar(get_tensor0_name)
+        setattr(self.prelude, get_tensor0_name, get_tensor0_var)
+        tensor_t = self.get_var('tensor_t')
+        tensor_var = self.get_var('tensor0')
+        t = Var('tensor', tensor_t())
+        tvar = Var('t')
+        case =\
+            Clause(PatternConstructor(tensor_var, [PatternVar(tvar)]), tvar)
+        self.prelude.mod[get_tensor0_var] =\
+            Function([t],
+                     Match(t, [case], False),
+                     TensorType([], self.dtype), [])
+
+    def define_get_tensor1(self):
+        """Defines a function to extract a rank 1 tensor out of tensor_t.
+ get_tensor1(t): tensor_t -> Tensor[(Any()), dtype] + """ + get_tensor1_name = self.get_name("get_tensor1") + get_tensor1_var = GlobalVar(get_tensor1_name) + setattr(self.prelude, get_tensor1_name, get_tensor1_var) + tensor_t = self.get_var('tensor_t') + tensor_var = self.get_var('tensor1') + t = Var('tensor', tensor_t()) + tvar = Var('t') + case =\ + Clause(PatternConstructor(tensor_var, [PatternVar(tvar)]), tvar) + self.prelude.mod[get_tensor1_var] =\ + Function([t], + Match(t, [case], False), + TensorType([Any()], self.dtype), []) + + def define_get_tensor2(self): + """Defines a function to extract a rank 2 tensor out of tensor_t. + get_tensor2(t): tensor_t -> Tensor[(Any(), Any()), dtype] + """ + get_tensor2_name = self.get_name("get_tensor2") + get_tensor2_var = GlobalVar(get_tensor2_name) + setattr(self.prelude, get_tensor2_name, get_tensor2_var) + tensor_t = self.get_var('tensor_t') + tensor_var = self.get_var('tensor2') + t = Var('tensor', tensor_t()) + tvar = Var('t') + case =\ + Clause(PatternConstructor(tensor_var, [PatternVar(tvar)]), tvar) + self.prelude.mod[get_tensor2_var] =\ + Function([t], + Match(t, [case], False), + TensorType([Any(), Any()], self.dtype), []) + + def define_get_tensor3(self): + """Defines a function to extract a rank 3 tensor out of tensor_t. + get_tensor3(t): tensor_t -> Tensor[(Any(), Any(), Any()), dtype] + """ + get_tensor3_name = self.get_name("get_tensor3") + get_tensor3_var = GlobalVar(get_tensor3_name) + setattr(self.prelude, get_tensor3_name, get_tensor3_var) + tensor_t = self.get_var('tensor_t') + tensor_var = self.get_var('tensor3') + t = Var('tensor', tensor_t()) + tvar = Var('t') + case =\ + Clause(PatternConstructor(tensor_var, [PatternVar(tvar)]), tvar) + self.prelude.mod[get_tensor3_var] =\ + Function([t], + Match(t, [case], False), + TensorType([Any(), Any(), Any()], self.dtype), []) + + def define_get_tensor4(self): + """Defines a function to extract a rank 4 tensor out of tensor_t. 
+ get_tensor4(t): tensor_t -> Tensor[(Any(), Any(), Any(), Any()), dtype] + """ + get_tensor4_name = self.get_name("get_tensor4") + get_tensor4_var = GlobalVar(get_tensor4_name) + setattr(self.prelude, get_tensor4_name, get_tensor4_var) + tensor_t = self.get_var('tensor_t') + tensor_var = self.get_var('tensor4') + t = Var('tensor', tensor_t()) + tvar = Var('t') + case =\ + Clause(PatternConstructor(tensor_var, [PatternVar(tvar)]), tvar) + self.prelude.mod[get_tensor4_var] =\ + Function([t], + Match(t, [case], False), + TensorType([Any(), Any(), Any(), Any()], self.dtype), []) + + def define_get_tensor5(self): + """Defines a function to extract a rank 5 tensor out of tensor_t. + get_tensor5(t): + tensor_t -> Tensor[(Any(), Any(), Any(), Any(), Any()), dtype] + """ + get_tensor5_name = self.get_name("get_tensor5") + get_tensor5_var = GlobalVar(get_tensor5_name) + setattr(self.prelude, get_tensor5_name, get_tensor5_var) + tensor_t = self.get_var('tensor_t') + tensor_var = self.get_var('tensor5') + t = Var('tensor', tensor_t()) + tvar = Var('t') + case =\ + Clause(PatternConstructor(tensor_var, [PatternVar(tvar)]), tvar) + self.prelude.mod[get_tensor5_var] =\ + Function([t], + Match(t, [case], False), + TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype), []) + + def define_get_tensor6(self): + """Defines a function to extract a rank 6 tensor out of tensor_t. 
+ get_tensor6(t): + tensor_t -> Tensor[(Any(), Any(), Any(), Any(), Any(), Any()), dtype] + """ + get_tensor6_name = self.get_name("get_tensor6") + get_tensor6_var = GlobalVar(get_tensor6_name) + setattr(self.prelude, get_tensor6_name, get_tensor6_var) + tensor_t = self.get_var('tensor_t') + tensor_var = self.get_var('tensor6') + t = Var('tensor', tensor_t()) + tvar = Var('t') + case =\ + Clause(PatternConstructor(tensor_var, [PatternVar(tvar)]), tvar) + self.prelude.mod[get_tensor6_var] =\ + Function([t], + Match(t, [case], False), + TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype), []) + def define_tensor_expand_dims(self): """Defines a function to grow a tensor_t's rank by adding one dimension in front of the original tensor_t. @@ -640,8 +768,14 @@ def register(self): self.define_tensor_array_split() self.define_tensor_array_concat() self.define_tensor_array_stack() - # TODO(wweic): Gather fails in PartialEvaluate - # self.define_tensor_array_gather() + self.define_get_tensor0() + self.define_get_tensor1() + self.define_get_tensor2() + self.define_get_tensor3() + self.define_get_tensor4() + self.define_get_tensor5() + self.define_get_tensor6() + self.define_tensor_array_gather() class Prelude: """Contains standard definitions.""" diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index 92f0db5d8ebe..84b51572f676 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -1008,7 +1008,7 @@ class PartialEvaluator : public ExprFunctor throw; } } - LOG(FATAL) << "No case Match"; + LOG(FATAL) << "No case Match for value " << op->data << "\n"; throw; }); } diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index e02532fa748b..67666d8a5efd 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -81,7 +81,7 @@ def vmobj_to_list(o): return hd elif o.constructor.name_hint == 
'Nil':
         return []
     elif 'tensor_nil' in o.constructor.name_hint:
         return [0]
     elif 'tensor' in o.constructor.name_hint:
         return [o.fields[0].asnumpy()]
@@ -680,6 +680,25 @@ def run(dtype_str):
     for dtype in tf_dtypes.keys():
         run(dtype)
 
+def test_tensor_array_read():
+    def run(dtype_str, input_shape):
+        with tf.Graph().as_default():
+            dtype = tf_dtypes[dtype_str]
+            data1 = np.random.choice([0, 1, 2, 3], size=input_shape)
+            data2 = np.random.choice([0, 1, 2, 3], size=input_shape)
+            t1 = tf.constant(data1.astype(dtype_str), dtype=dtype)
+            t2 = tf.constant(data2.astype(dtype_str), dtype=dtype)
+            ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
+            ta2 = ta1.write(0, t1)
+            ta3 = ta2.write(1, t2)
+            out = ta3.read(0)
+            compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug')
+    for dtype in tf_dtypes.keys():
+        run(dtype, (2, 2))
+        run(dtype, (2, 2, 2))
+        run(dtype, (2, 2, 2, 2))
+        run(dtype, (2, 2, 2, 2, 2))
+        run(dtype, (2, 2, 2, 2, 2, 2))
 
 def test_tensor_array_scatter():
     def run(dtype_str):
@@ -693,26 +712,23 @@ def run(dtype_str):
         out0 = ta2.read(0)
         out1 = ta2.read(1)
         out2 = ta2.read(2)
-        g = tf.get_default_graph()
         compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug')
         compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
         compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
     for dtype in tf_dtypes.keys():
         run(dtype)
 
-# TODO(wweic): Fix gather issue with PartialEvaluate
-# def test_tensor_array_gather():
-#     with tf.Graph().as_default():
-#         dtype = 'float32'
-#         t = tf.constant([[1.0], [2.0], [3.0]])
-#         scatter_indices = tf.constant([2, 1, 0])
-#         gather_indices = tf.constant([1, 2])
-#         ta1 = tf.TensorArray(dtype=tf.float32, size=3, infer_shape=False, dynamic_size=False)
-#         ta2 = ta1.scatter(scatter_indices, t)
-#         t1 = ta2.gather(gather_indices)
-#         g = tf.get_default_graph()
-#         compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='debug')
-
+def test_tensor_array_gather():
+    with tf.Graph().as_default():
+        dtype = 'float32'
+        t = tf.constant([[1.0], [2.0], [3.0]])
+        scatter_indices = tf.constant([2, 1, 0])
+        gather_indices = tf.constant([1, 2])
+        ta1 = tf.TensorArray(dtype=tf.float32, size=3, infer_shape=False, dynamic_size=False)
+        ta2 = ta1.scatter(scatter_indices, t)
+        t1 = ta2.gather(gather_indices)
+        g = tf.get_default_graph()
+        compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='debug')
 
 def test_tensor_array_split():
     def run(dtype_str):
diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py
index cf4f8f6cee74..e6215181eace 100644
--- a/tests/python/relay/test_pass_partial_eval.py
+++ b/tests/python/relay/test_pass_partial_eval.py
@@ -50,8 +50,9 @@ def tipe(expr):
                              transform.InferType()])
 
 
-def dcpe(expr, mod=None, grad=False):
-    passes = [transform.PartialEvaluate(),
+def dcpe(expr, mod=None, entry_funcs=None, grad=False):
+    passes = [transform.RemoveUnusedFunctions(entry_funcs),
+              transform.PartialEvaluate(),
               transform.DeadCodeElimination(inline_once=True)]
     if grad:
         expr = gradient(run_infer_type(expr))
@@ -195,7 +196,7 @@ def test_map():
     mod["main"] = expected
     expected = mod["main"]
     orig = Function([], orig)
-    res = dcpe(orig, mod=mod)
+    res = dcpe(orig, mod=mod, entry_funcs=['f', 'main'])
     assert alpha_equal(res.body, expected.body)
 
 
@@ -328,9 +329,9 @@ def test_nat_update():
     p = Prelude(m)
     add_nat_definitions(p)
     m = transform.ToANormalForm()(m)
+    m = transform.RemoveUnusedFunctions([])(m)
     transform.PartialEvaluate()(m)
 
-
 def test_tuple_match():
     a = relay.Var("a")
     b = relay.Var("b")