Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 96 additions & 8 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,23 @@ def _need_module_for_shape_inference(op):
def _need_prelude_for_shape_inference(op):
return "TensorArray" in op

def _extract_tensor_data(prelude, rank, dtype, source):
get_tensor_name = 'get_tensor{}'.format(rank)
get_tensor_func = prelude.get_var(get_tensor_name, dtype)
return get_tensor_func(source)

def _get_tensor_shape(tensor, mod):
    """Best-effort static shape of a relay expression.

    Only relay Vars with an explicit type annotation are handled so far;
    `mod` is reserved for running type inference on other expressions.

    Raises
    ------
    NotImplementedError
        If `tensor` is not an annotated relay Var.
    """
    if isinstance(tensor, tvm.relay.expr.Var):
        return tensor.type_annotation.shape
    # TODO: run infer_type on `mod` to support arbitrary expressions.
    raise NotImplementedError(
        "shape inference for non-Var expressions is not implemented yet")

def _get_tensor_rank(tensor, mod):
    """Best-effort static rank of a relay expression.

    Only relay Vars with an explicit type annotation are handled so far;
    `mod` is reserved for running type inference on other expressions.

    Raises
    ------
    NotImplementedError
        If `tensor` is not an annotated relay Var.
    """
    if isinstance(tensor, tvm.relay.expr.Var):
        return len(tensor.type_annotation.shape)
    # TODO: run infer_type on `mod` to support arbitrary expressions.
    raise NotImplementedError(
        "rank inference for non-Var expressions is not implemented yet")

def _rsqrt():
def _impl(inputs, attr, params):
inputs.append(tvm.relay.const(-0.5, attr['T'].name))
Expand Down Expand Up @@ -538,17 +555,31 @@ def _impl(inputs, attr, params, prelude):
def _tensor_array_scatter():
    def _impl(inputs, attr, params, prelude):
        """Convert TensorArrayScatterV3: unstack the values tensor along its
        first axis and scatter the slices into the tensor array."""
        dtype_str = attr.get('T').name
        values_shape = _get_tensor_shape(inputs[2], prelude.mod)
        values_rank = len(values_shape)
        unstack_name = "tensor_array_unstack_tensor{}".format(values_rank)
        unstack_function = prelude.get_var(unstack_name, dtype_str)
        values = unstack_function(inputs[2])

        tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str)
        result = tensor_array_scatter_func(inputs[0], inputs[1], values)
        # Each stored element drops the leading (scatter) axis; record the
        # element shape as a tuple so set membership in the tracker works.
        _tensor_array_shape_tracker.trace(result, tuple(values_shape[1:]))
        return result
    return _impl

def _tensor_array_gather():
    def _impl(inputs, attr, params, prelude):
        """Convert TensorArrayGatherV3: gather the selected elements and, when
        the element shape is known, unwrap the ADT result at compile time."""
        dtype_str = attr.get('dtype').name
        tensor_array_gather_func = prelude.get_var('tensor_array_gather', dtype_str)
        result = tensor_array_gather_func(inputs[2], inputs[1])

        shape = _tensor_array_shape_tracker.get_shape(inputs[2])
        if shape is not None:
            # Gather stacks the selected tensors, so the result rank is the
            # element rank plus one.
            return _extract_tensor_data(prelude,
                                        len(shape) + 1,
                                        dtype_str,
                                        result)
        return result
    return _impl

def _tensor_array_size():
Expand All @@ -558,31 +589,51 @@ def _impl(inputs, attr, params, prelude):

def _tensor_array_write():
    def _impl(inputs, attr, params, prelude):
        """Convert TensorArrayWriteV3: wrap the value in the rank-specific
        tensor ADT constructor and write it at the given index, recording the
        element shape for later reads/gathers."""
        # Use the shared helper for consistency with scatter/split.
        input_shape = _get_tensor_shape(inputs[2], prelude.mod)
        input_rank = len(input_shape)
        dtype = attr.get('T').name

        tensor_name = 'tensor{}'.format(input_rank)
        tensor_func = prelude.get_var(tensor_name, dtype)
        v = tensor_func(inputs[2])
        write_func = prelude.get_var('tensor_array_write', dtype)

        res = write_func(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v)
        # Record the shape as a tuple: the tracker stores shapes in a set, so
        # entries must be hashable and compare equal with the tuples traced by
        # the scatter/split converters.
        _tensor_array_shape_tracker.trace(inputs[3], tuple(input_shape))
        # The write result denotes the same tensor array as inputs[3]; merge
        # their shape-tracking groups.
        _tensor_array_shape_tracker.union(inputs[3], res)
        return res
    return _impl

def _tensor_array_read():
    def _impl(inputs, attr, params, prelude):
        """Convert TensorArrayReadV3, unwrapping the ADT result at compile
        time when the element shape is statically known."""
        read_func = prelude.get_var('tensor_array_read', attr.get('dtype').name)
        # BUG FIX: the previous code computed len(get_shape(...)) first and
        # then tested the *rank* for None — len() never returns None and
        # raises TypeError when get_shape() legitimately returns None.
        # Test the shape itself before taking its length.
        shape = _tensor_array_shape_tracker.get_shape(inputs[2])
        if shape is not None:
            # All tensors in the tensor array share one static rank, so we can
            # emit code to extract the tensor out of the ADT object at compile
            # time.
            source = read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
            return _extract_tensor_data(prelude,
                                        len(shape),
                                        attr['dtype'].name,
                                        source)
        # The tensor array contains tensors of various ranks; return the raw
        # ADT value (expected to be the rare case).
        return read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
    return _impl

def _tensor_array_split():
    def _impl(inputs, attr, params, prelude):
        """Convert TensorArraySplitV3: split the value tensor along axis 0 by
        `lengths` and store the pieces in the tensor array."""
        input_shape = _get_tensor_shape(inputs[1], prelude.mod)
        input_rank = len(input_shape)
        dtype_str = attr.get('T').name
        v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1])
        lengths = _op.cast(inputs[2], 'int32')
        split_func = prelude.get_var('tensor_array_split', dtype_str)
        result = split_func(inputs[0], v, lengths)

        # The leading dimension of each stored element is data-dependent
        # (driven by `lengths`), so it is recorded as unknown (None).
        # list(...) is needed because `input_shape` may be a tvm Array, which
        # cannot be concatenated onto a Python list directly.
        result_shape = [None] + list(input_shape[1:])
        _tensor_array_shape_tracker.trace(result, tuple(result_shape))
        return result
    return _impl

def _tensor_array_concat():
Expand Down Expand Up @@ -1371,6 +1422,43 @@ def _impl(inputs, attr, params):
return _res
return _impl

class TensorArrayShapeTracker():
    """Tracks the element shapes stored into each tensor array expression.

    Expressions that denote the same tensor array (e.g. a write's input and
    its result) are merged with a union-find structure so shape information
    propagates across converter calls.
    """

    def __init__(self):
        # Group representative -> set of traced element shapes (tuples).
        self.expr_shapes_map_ = defaultdict(set)
        # Union-find parent pointers over relay expressions.
        self.expr_union_find = {}

    def find(self, expr):
        """Return the representative of `expr`'s group (no path compression)."""
        if expr not in self.expr_union_find:
            self.expr_union_find[expr] = expr
        while self.expr_union_find[expr] != expr:
            expr = self.expr_union_find[expr]
        return expr

    def union(self, a, b):
        """Merge the groups of `a` and `b`, combining their traced shapes."""
        group_a = self.find(a)
        group_b = self.find(b)
        if group_a != group_b:
            self.expr_union_find[group_a] = group_b
            self.expr_shapes_map_[group_b] |= self.expr_shapes_map_[group_a]

    def trace(self, expr, shape):
        """Record that `expr`'s tensor array holds elements of `shape`.

        BUG FIX: record against the group representative, not `expr` itself —
        otherwise a shape traced after `expr` has been unioned into another
        group would be invisible to get_shape().
        """
        self.expr_shapes_map_[self.find(expr)].add(shape)

    def get_shape(self, expr):
        """Return the unique traced shape for `expr`, or None if nothing was
        traced for its group.

        Raises
        ------
        ValueError
            If conflicting shapes were traced for the same group.  (The
            previous code did ``raise None``, which itself raises TypeError.)
        """
        shapes = self.expr_shapes_map_[self.find(expr)]
        if not shapes:
            return None
        # Shapes may be wrappers that don't compare by value; dedup by their
        # string form (kept from the original as a pragmatic workaround).
        if len({str(s) for s in shapes}) == 1:
            return next(iter(shapes))
        raise ValueError("conflicting tensor array shapes: {}".format(shapes))

# Remember the rank of tensors stored in each tensor array
# NOTE(review): `_tensor_array_ranks` is not referenced anywhere in the visible
# code — looks like dead state superseded by the shape tracker; confirm before
# removing.
_tensor_array_ranks = defaultdict(set)
# Module-level tracker shared by the TensorArray converters above.
_tensor_array_shape_tracker = TensorArrayShapeTracker()

# compatible operators that do NOT require any conversion.
_identity_list = []
Expand Down
138 changes: 136 additions & 2 deletions python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,134 @@ def define_tensor_take(self):
tensor6_case], False),
tensor_t(), [])

def define_get_tensor0(self):
    """Defines a function to extract a rank 0 tensor out of tensor_t.
    get_tensor0(t): tensor_t -> Tensor[(), dtype]
    """
    # Register the new global under the dtype-mangled prelude name.
    func_name = self.get_name("get_tensor0")
    func_var = GlobalVar(func_name)
    setattr(self.prelude, func_name, func_var)
    tensor_t = self.get_var('tensor_t')
    ctor = self.get_var('tensor0')
    arg = Var('tensor', tensor_t())
    payload = Var('t')
    # Match the tensor0 constructor and return its wrapped tensor payload.
    unwrap_case = Clause(PatternConstructor(ctor, [PatternVar(payload)]), payload)
    self.prelude.mod[func_var] = \
        Function([arg],
                 Match(arg, [unwrap_case], False),
                 TensorType([], self.dtype), [])

def define_get_tensor1(self):
    """Defines a function to extract a rank 1 tensor out of tensor_t.
    get_tensor1(t): tensor_t -> Tensor[(Any()), dtype]
    """
    # Register the new global under the dtype-mangled prelude name.
    func_name = self.get_name("get_tensor1")
    func_var = GlobalVar(func_name)
    setattr(self.prelude, func_name, func_var)
    tensor_t = self.get_var('tensor_t')
    ctor = self.get_var('tensor1')
    arg = Var('tensor', tensor_t())
    payload = Var('t')
    # Match the tensor1 constructor and return its wrapped tensor payload.
    unwrap_case = Clause(PatternConstructor(ctor, [PatternVar(payload)]), payload)
    self.prelude.mod[func_var] = \
        Function([arg],
                 Match(arg, [unwrap_case], False),
                 TensorType([Any()], self.dtype), [])

def define_get_tensor2(self):
    """Defines a function to extract a rank 2 tensor out of tensor_t.
    get_tensor2(t): tensor_t -> Tensor[(Any(), Any()), dtype]
    """
    # Register the new global under the dtype-mangled prelude name.
    func_name = self.get_name("get_tensor2")
    func_var = GlobalVar(func_name)
    setattr(self.prelude, func_name, func_var)
    tensor_t = self.get_var('tensor_t')
    ctor = self.get_var('tensor2')
    arg = Var('tensor', tensor_t())
    payload = Var('t')
    # Match the tensor2 constructor and return its wrapped tensor payload.
    unwrap_case = Clause(PatternConstructor(ctor, [PatternVar(payload)]), payload)
    self.prelude.mod[func_var] = \
        Function([arg],
                 Match(arg, [unwrap_case], False),
                 TensorType([Any(), Any()], self.dtype), [])

def define_get_tensor3(self):
    """Defines a function to extract a rank 3 tensor out of tensor_t.
    get_tensor3(t): tensor_t -> Tensor[(Any(), Any(), Any()), dtype]
    """
    # Register the new global under the dtype-mangled prelude name.
    func_name = self.get_name("get_tensor3")
    func_var = GlobalVar(func_name)
    setattr(self.prelude, func_name, func_var)
    tensor_t = self.get_var('tensor_t')
    ctor = self.get_var('tensor3')
    arg = Var('tensor', tensor_t())
    payload = Var('t')
    # Match the tensor3 constructor and return its wrapped tensor payload.
    unwrap_case = Clause(PatternConstructor(ctor, [PatternVar(payload)]), payload)
    self.prelude.mod[func_var] = \
        Function([arg],
                 Match(arg, [unwrap_case], False),
                 TensorType([Any(), Any(), Any()], self.dtype), [])

def define_get_tensor4(self):
    """Defines a function to extract a rank 4 tensor out of tensor_t.
    get_tensor4(t): tensor_t -> Tensor[(Any(), Any(), Any(), Any()), dtype]
    """
    # Register the new global under the dtype-mangled prelude name.
    func_name = self.get_name("get_tensor4")
    func_var = GlobalVar(func_name)
    setattr(self.prelude, func_name, func_var)
    tensor_t = self.get_var('tensor_t')
    ctor = self.get_var('tensor4')
    arg = Var('tensor', tensor_t())
    payload = Var('t')
    # Match the tensor4 constructor and return its wrapped tensor payload.
    unwrap_case = Clause(PatternConstructor(ctor, [PatternVar(payload)]), payload)
    self.prelude.mod[func_var] = \
        Function([arg],
                 Match(arg, [unwrap_case], False),
                 TensorType([Any(), Any(), Any(), Any()], self.dtype), [])

def define_get_tensor5(self):
    """Defines a function to extract a rank 5 tensor out of tensor_t.
    get_tensor5(t):
    tensor_t -> Tensor[(Any(), Any(), Any(), Any(), Any()), dtype]
    """
    # Register the new global under the dtype-mangled prelude name.
    func_name = self.get_name("get_tensor5")
    func_var = GlobalVar(func_name)
    setattr(self.prelude, func_name, func_var)
    tensor_t = self.get_var('tensor_t')
    ctor = self.get_var('tensor5')
    arg = Var('tensor', tensor_t())
    payload = Var('t')
    # Match the tensor5 constructor and return its wrapped tensor payload.
    unwrap_case = Clause(PatternConstructor(ctor, [PatternVar(payload)]), payload)
    self.prelude.mod[func_var] = \
        Function([arg],
                 Match(arg, [unwrap_case], False),
                 TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype), [])

def define_get_tensor6(self):
    """Defines a function to extract a rank 6 tensor out of tensor_t.
    get_tensor6(t):
    tensor_t -> Tensor[(Any(), Any(), Any(), Any(), Any(), Any()), dtype]
    """
    # Register the new global under the dtype-mangled prelude name.
    func_name = self.get_name("get_tensor6")
    func_var = GlobalVar(func_name)
    setattr(self.prelude, func_name, func_var)
    tensor_t = self.get_var('tensor_t')
    ctor = self.get_var('tensor6')
    arg = Var('tensor', tensor_t())
    payload = Var('t')
    # Match the tensor6 constructor and return its wrapped tensor payload.
    unwrap_case = Clause(PatternConstructor(ctor, [PatternVar(payload)]), payload)
    self.prelude.mod[func_var] = \
        Function([arg],
                 Match(arg, [unwrap_case], False),
                 TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype), [])

def define_tensor_expand_dims(self):
"""Defines a function to grow a tensor_t's rank by adding one dimension in front
of the original tensor_t.
Expand Down Expand Up @@ -640,8 +768,14 @@ def register(self):
self.define_tensor_array_split()
self.define_tensor_array_concat()
self.define_tensor_array_stack()
# TODO(wweic): Gather fails in PartialEvaluate
# self.define_tensor_array_gather()
self.define_get_tensor0()
self.define_get_tensor1()
self.define_get_tensor2()
self.define_get_tensor3()
self.define_get_tensor4()
self.define_get_tensor5()
self.define_get_tensor6()
self.define_tensor_array_gather()

class Prelude:
"""Contains standard definitions."""
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/partial_eval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
throw;
}
}
LOG(FATAL) << "No case Match";
LOG(FATAL) << "No case Match for value " << op->data << "\n";
throw;
});
}
Expand Down
46 changes: 31 additions & 15 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def vmobj_to_list(o):
return hd
elif o.constructor.name_hint == 'Nil':
return []
elif 'tensor_nil' in o.constructor.name_hint:
    elif 'tensor_nil' in o.constructor.name_hint:
return [0]
elif 'tensor' in o.constructor.name_hint:
return [o.fields[0].asnumpy()]
Expand Down Expand Up @@ -680,6 +680,25 @@ def run(dtype_str):
for dtype in tf_dtypes.keys():
run(dtype)

def test_tensor_array_read():
    """TensorArrayReadV3: write two tensors of every dtype/rank combination
    and read the first one back."""
    def run(dtype_str, input_shape):
        with tf.Graph().as_default():
            dtype = tf_dtypes[dtype_str]
            np_first = np.random.choice([0, 1, 2, 3], size=input_shape)
            np_second = np.random.choice([0, 1, 2, 3], size=input_shape)
            const_first = tf.constant(np_first.astype(dtype_str), dtype=dtype)
            const_second = tf.constant(np_second.astype(dtype_str), dtype=dtype)
            ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
            ta2 = ta1.write(0, const_first)
            ta3 = ta2.write(1, const_second)
            out = ta3.read(0)
            compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug')
    for dtype in tf_dtypes.keys():
        run(dtype, (2, 2))
        run(dtype, (2, 2, 2))
        run(dtype, (2, 2, 2, 2))
        run(dtype, (2, 2, 2, 2, 2))
        run(dtype, (2, 2, 2, 2, 2, 2))

def test_tensor_array_scatter():
def run(dtype_str):
Expand All @@ -693,26 +712,23 @@ def run(dtype_str):
out0 = ta2.read(0)
out1 = ta2.read(1)
out2 = ta2.read(2)
g = tf.get_default_graph()
compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
for dtype in tf_dtypes.keys():
run(dtype)

# TODO(wweic): Fix gather issue with PartialEvaluate
# def test_tensor_array_gather():
# with tf.Graph().as_default():
# dtype = 'float32'
# t = tf.constant([[1.0], [2.0], [3.0]])
# scatter_indices = tf.constant([2, 1, 0])
# gather_indices = tf.constant([1, 2])
# ta1 = tf.TensorArray(dtype=tf.float32, size=3, infer_shape=False, dynamic_size=False)
# ta2 = ta1.scatter(scatter_indices, t)
# t1 = ta2.gather(gather_indices)
# g = tf.get_default_graph()
# compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='debug')

def test_tensor_array_gather():
    """TensorArrayGatherV3: scatter three rank-1 tensors, gather two back.

    Cleaned up dead locals from the original (`dtype`, `t1`, `g` were
    assigned but never used).
    """
    with tf.Graph().as_default():
        t = tf.constant([[1.0], [2.0], [3.0]])
        scatter_indices = tf.constant([2, 1, 0])
        gather_indices = tf.constant([1, 2])
        ta1 = tf.TensorArray(dtype=tf.float32, size=3, infer_shape=False, dynamic_size=False)
        ta2 = ta1.scatter(scatter_indices, t)
        # The call must still happen so the gather node exists in the graph
        # for compare_tf_with_tvm to resolve by name.
        ta2.gather(gather_indices)
        compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='debug')

def test_tensor_array_split():
def run(dtype_str):
Expand Down
Loading