Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 26 additions & 70 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,16 @@ def _impl(inputs, input_types):
return get_relay_op(name)(data0, data1)
return _impl

def _abs():

def _unary(name):
def _impl(inputs, input_types):
data = inputs[0]
return _op.abs(data)
input_type = input_types[0]
data = _convert_elemwise_input(inputs[0], input_type)

return get_relay_op(name)(data)
return _impl


def _arange():
def _impl(inputs, input_types):
if len(inputs) == 5:
Expand Down Expand Up @@ -1254,26 +1258,6 @@ def _impl(inputs, input_types):
return _op.nn.pad(data, pad_width, pad_value)
return _impl

def _sqrt():
def _impl(inputs, input_types):
data = inputs[0]
return _op.tensor.sqrt(data)
return _impl


def _rsqrt():
def _impl(inputs, input_types):
data = inputs[0]
return _op.tensor.rsqrt(data)
return _impl


def _ceil():
def _impl(inputs, input_types):
data = inputs[0]
return _op.ceil(data)
return _impl


def _clamp():
def _impl(inputs, input_types):
Expand All @@ -1284,20 +1268,6 @@ def _impl(inputs, input_types):
return _impl


def _floor():
def _impl(inputs, input_types):
data = inputs[0]
return _op.floor(data)
return _impl


def _round():
def _impl(inputs, input_types):
data = inputs[0]
return _op.round(data)
return _impl


def _to():
def _impl(inputs, input_types):
data = inputs[0]
Expand Down Expand Up @@ -1375,17 +1345,6 @@ def _impl(inputs, input_types):
return inputs[0]
return _impl

def _neg():
def _impl(inputs, input_types):
data = inputs[0]
return _op.tensor.negative(data)
return _impl

def _tanh():
def _impl(inputs, input_types):
data = inputs[0]
return _op.tensor.tanh(data)
return _impl

def _Bool():
def _impl(inputs, input_types):
Expand Down Expand Up @@ -1467,18 +1426,6 @@ def _impl(inputs, input_types):
return _impl


def _isfinite():
def _impl(inputs, input_types):
return _op.isfinite(inputs[0])
return _impl


def _isnan():
def _impl(inputs, input_types):
return _op.isnan(inputs[0])
return _impl


def _list_getitem(prelude):
def _impl(inputs, input_types):
return prelude.nth(inputs[0], _wrap_const(inputs[1]))
Expand Down Expand Up @@ -1601,7 +1548,6 @@ def _get_convert_map(prelude):
"aten::mul" : _elemwise("multiply"),
"aten::mul_" : _elemwise("multiply"),
"aten::pow" : _elemwise("power"),
"aten::abs" : _abs(),
"aten::arange" : _arange(),
"aten::div" : _elemwise("divide"),
"aten::div_" : _elemwise("divide"),
Expand Down Expand Up @@ -1683,12 +1629,26 @@ def _get_convert_map(prelude):
"aten::argmax" : _reduce("argmax"),
"aten::std" : _std(),
"aten::var" : _variance(),
"aten::sqrt" : _sqrt(),
"aten::rsqrt" : _rsqrt(),
"aten::ceil" : _ceil(),
"aten::abs" : _unary("abs"),
"aten::neg" : _unary("negative"),
"aten::cos" : _unary("cos"),
"aten::sin" : _unary("sin"),
"aten::tan" : _unary("tan"),
"aten::tanh" : _unary("tanh"),
"aten::atan" : _unary("atan"),
"aten::log" : _unary("log"),
"aten::exp" : _unary("exp"),
"aten::erf" : _unary("erf"),
"aten::trunc" : _unary("trunc"),
"aten::sign" : _unary("sign"),
"aten::sqrt" : _unary("sqrt"),
"aten::rsqrt" : _unary("rsqrt"),
"aten::ceil" : _unary("ceil"),
"aten::floor" : _unary("floor"),
"aten::round" : _unary("round"),
"aten::isfinite" : _unary("isfinite"),
"aten::isnan" : _unary("isnan"),
"aten::clamp" : _clamp(),
"aten::floor" : _floor(),
"aten::round" : _round(),
"aten::detach" : _identity(),
"aten::upsample_bilinear2d" : _upsample("bilinear"),
"aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
Expand All @@ -1703,12 +1663,8 @@ def _get_convert_map(prelude):
"aten::logical_xor" : _logical_xor(),
"aten::bitwise_not" : _bitwise_not(),
"aten::bitwise_xor" : _bitwise_xor(),
"aten::isfinite" : _isfinite(),
"aten::isnan" : _isnan(),
"aten::Bool" : _Bool(),
"aten::Float" : _Float(),
"aten::neg" : _neg(),
"aten::tanh" : _tanh(),
"aten::adaptive_avg_pool3d" : _adaptive_avg_pool_3d(),
"aten::adaptive_max_pool3d" : _adaptive_max_pool_3d(),
"aten::mm" : _matmul(),
Expand Down
141 changes: 88 additions & 53 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,30 +1497,6 @@ def forward(self, *args):
verify_model(IsInf1().float().eval(), input_data=input_data)


def test_forward_rsqrt():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

class Rsqrt1(Module):
def forward(self, *args):
return torch.rsqrt(args[0])

input_data = torch.rand(input_shape).float()
verify_model(Rsqrt1().float().eval(), input_data=input_data)


def test_forward_ceil():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

class Ceil1(Module):
def forward(self, *args):
return torch.ceil(args[0])

input_data = torch.rand(input_shape).float()
verify_model(Ceil1().float().eval(), input_data=input_data)


def test_forward_clamp():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
Expand All @@ -1543,30 +1519,6 @@ def forward(self, *args):
verify_model(Clamp3().float().eval(), input_data=input_data)


def test_forward_floor():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

class Floor1(Module):
def forward(self, *args):
return torch.floor(args[0])

input_data = torch.rand(input_shape).float()
verify_model(Floor1().float().eval(), input_data=input_data)


def test_forward_round():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

class Round1(Module):
def forward(self, *args):
return torch.round(args[0])

input_data = torch.rand(input_shape).float()
verify_model(Round1().float().eval(), input_data=input_data)


def test_forward_ones():
torch.set_grad_enabled(False)

Expand Down Expand Up @@ -1849,6 +1801,93 @@ def forward(self, *args):
verify_model(LogicalXor2().float().eval(), input_data=[lhs])


def test_forward_unary():
torch.set_grad_enabled(False)

class Sqrt1(Module):
def forward(self, *args):
return torch.sqrt(args[0])

class RSqrt1(Module):
def forward(self, *args):
return torch.rsqrt(args[0])

class Ceil1(Module):
def forward(self, *args):
return torch.ceil(args[0])

class Floor1(Module):
def forward(self, *args):
return torch.floor(args[0])

class Round1(Module):
def forward(self, *args):
return torch.round(args[0])

class Cos1(Module):
def forward(self, *args):
return torch.cos(args[0])

class Sin1(Module):
def forward(self, *args):
return torch.sin(args[0])

class Tan1(Module):
def forward(self, *args):
return torch.tan(args[0])

class Tanh1(Module):
def forward(self, *args):
return torch.tanh(args[0])

class ATanh1(Module):
def forward(self, *args):
return torch.atan(args[0])

class Log1(Module):
def forward(self, *args):
return torch.log(args[0])

class Exp1(Module):
def forward(self, *args):
return torch.exp(args[0])

class Erf1(Module):
def forward(self, *args):
return torch.erf(args[0])

class Trunc1(Module):
def forward(self, *args):
return torch.trunc(args[0])

class Sign1(Module):
def forward(self, *args):
return torch.sign(args[0])

class Neg1(Module):
def forward(self, *args):
return torch.neg(args[0])

input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float()
verify_model(Sqrt1().float().eval(), input_data=input_data)
verify_model(RSqrt1().float().eval(), input_data=input_data)
verify_model(Ceil1().float().eval(), input_data=input_data)
verify_model(Floor1().float().eval(), input_data=input_data)
verify_model(Round1().float().eval(), input_data=input_data)
verify_model(Cos1().float().eval(), input_data=input_data)
verify_model(Sin1().float().eval(), input_data=input_data)
verify_model(Tan1().float().eval(), input_data=input_data)
verify_model(Tanh1().float().eval(), input_data=input_data)
verify_model(ATanh1().float().eval(), input_data=input_data)
verify_model(Log1().float().eval(), input_data=input_data)
verify_model(Exp1().float().eval(), input_data=input_data)
verify_model(Erf1().float().eval(), input_data=input_data)
verify_model(Trunc1().float().eval(), input_data=input_data)
verify_model(Sign1().float().eval(), input_data=input_data)
verify_model(Neg1().float().eval(), input_data=input_data)


if __name__ == "__main__":
# Single operator tests
test_forward_add()
Expand Down Expand Up @@ -1907,12 +1946,8 @@ def forward(self, *args):
test_forward_mean()
test_forward_expand()
test_forward_pow()
test_forward_abs()
test_forward_rsqrt()
test_forward_ceil()
test_forward_unary()
test_forward_clamp()
test_forward_floor()
test_forward_round()
test_forward_logical_not()
test_forward_bitwise_not()
test_forward_bitwise_xor()
Expand Down