diff --git a/python/tvm/dlight/gpu/general_reduction.py b/python/tvm/dlight/gpu/general_reduction.py
index d3979ce0e4c3..b1564bf61fa9 100644
--- a/python/tvm/dlight/gpu/general_reduction.py
+++ b/python/tvm/dlight/gpu/general_reduction.py
@@ -61,6 +61,23 @@ def apply(  # pylint: disable=too-many-locals
         # Align the number of block iters of the last block.
         num_last_block_iter = len(block_infos[-1].dom_kind())
         if num_last_block_iter < len(dom_kind):
+            # If the last block is a scalar value, there is nothing left to
+            # tile/parallelise, and `iters` is an empty tuple.
+            # Add a unit thread loop so the final write happens inside a valid
+            # GPU thread environment.
+            if num_last_block_iter == 0:
+                # Put every block (both the running reductions and the final
+                # scalar write) inside a trivial GPU thread. The very first block
+                # gets a `blockIdx.x` wrapper so that kernels still have a unique
+                # block scope.
+                for i, info in enumerate(block_infos):
+                    loop_rv = sch.add_unit_loop(info.block_rv)
+                    if i == 0:
+                        sch.bind(loop_rv, "blockIdx.x")
+                    else:
+                        sch.bind(loop_rv, "threadIdx.x")
+
+                return sch
 
             def f_layout_mapping(*iters):
                 analyzer = arith.Analyzer()
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 48869767ad66..d30b0afc4813 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -776,6 +776,25 @@ def _conv3d(self, node: fx.Node) -> relax.Var:
             groups=groups,
         )
 
+    def _cross_entropy_loss(
+        self,
+        preds: relax.Expr,
+        targets: relax.Expr,
+        weights: Optional[relax.Expr],
+        reduction: str,
+        ignore_index: int,
+    ) -> relax.Expr:
+        log_probs = relax.op.nn.log_softmax(preds)
+        return self.block_builder.emit(
+            relax.op.nn.nll_loss(
+                log_probs,
+                targets,
+                weights,
+                reduction,
+                ignore_index,
+            )
+        )
+
     def _einsum(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index df532fd1ea04..631677d609b9 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -66,7 +66,7 @@ def _reciprocal(self, node: fx.Node) -> relax.Var:
 
     ########## Neural Network ##########
 
-    def _batch_norm(self, node: fx.Node, training) -> relax.Var:
+    def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
         import numpy as np
 
         x = self.env[node.args[0]]
@@ -113,6 +113,14 @@ def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var:
         training = False
         return self._batch_norm(node, training)
 
+    def _cross_entropy_default(self, node: fx.Node) -> relax.Expr:
+        preds = self.env[node.args[0]]
+        targets = self.env[node.args[1]]
+        weight = self.env.get(node.args[2], None) if len(node.args) > 2 else None
+        reduction = node.kwargs.get("reduction", "mean")
+        ignore_index = node.kwargs.get("ignore_index", -100)
+        return self._cross_entropy_loss(preds, targets, weight, reduction, ignore_index)
+
     def _group_norm(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         num_groups = node.args[1]
@@ -399,6 +407,7 @@ def create_convert_map(
             "conv1d.default": self._conv1d,
             "conv2d.default": self._conv2d,
             "conv3d.default": self._conv3d,
+            "cross_entropy_loss.default": self._cross_entropy_default,
             "einsum.default": self._einsum,
             "embedding.default": lambda node: self._embedding_impl(
                 self.env[node.args[1]], self.env[node.args[0]]
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 5f65f86a4303..bdbf7aba6aed 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -308,12 +308,7 @@ def _cross_entropy(self, node: fx.Node) -> relax.Expr:
         weights = self.env.get(node.kwargs["weight"], None)
         reduction = node.kwargs["reduction"]
         ignore_index = node.kwargs["ignore_index"]
-
-        return self.block_builder.emit(
-            relax.op.nn.nll_loss(
-                relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index
-            )
-        )
+        return self._cross_entropy_loss(preds, targets, weights, reduction, ignore_index)
 
     def _cross_entropy_module(self, node: fx.Node) -> relax.Expr:
         preds = self.env[node.args[0]]
@@ -330,10 +325,12 @@ def _cross_entropy_module(self, node: fx.Node) -> relax.Expr:
         reduction = module.reduction
         ignore_index = module.ignore_index
 
-        return self.block_builder.emit(
-            relax.op.nn.nll_loss(
-                relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index
-            )
+        return self._cross_entropy_loss(
+            preds,
+            targets,
+            weights,
+            reduction,
+            ignore_index,
         )
 
     def _embedding_module(self, node: fx.Node) -> relax.Var:
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/nightly/test_nnapi/test_from_exported_to_cuda.py
similarity index 70%
rename from tests/python/relax/test_from_exported_to_cuda.py
rename to tests/python/nightly/test_nnapi/test_from_exported_to_cuda.py
index 6bb35b50b1df..3f0964cfa8ed 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/nightly/test_nnapi/test_from_exported_to_cuda.py
@@ -21,6 +21,7 @@
 import numpy as np
 import torch
 from torch import nn
+from torch.nn import functional as F
 from torch.export import export
 from tvm.relax.frontend.torch import from_exported_program
 from torch.nn import Softmax, Upsample
@@ -742,5 +743,332 @@ def forward(self, x):
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_leakyrelu_module(target, dev):
+    class LeakyReLUModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.act = nn.LeakyReLU(negative_slope=0.1)
+
+        def forward(self, x):
+            return self.act(x)
+
+    raw_data = np.random.randn(2, 3).astype(np.float32)
+    torch_module = LeakyReLUModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_log_softmax_module(target, dev):
+    class LogSoftmaxModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.logsoftmax = nn.LogSoftmax(dim=1)
+
+        def forward(self, x):
+            return self.logsoftmax(x)
+
+    raw_data = np.random.randn(4, 5).astype(np.float32)
+    torch_module = LogSoftmaxModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_softmax_module(target, dev):
+    class SoftmaxModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.softmax = nn.Softmax(dim=1)
+
+        def forward(self, x):
+            return self.softmax(x)
+
+    raw_data = np.random.randn(4, 5).astype(np.float32)
+    torch_module = SoftmaxModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_adaptive_avg_pool2d_module(target, dev):
+    class AdaptiveAvgPool2dModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = nn.AdaptiveAvgPool2d((1, 1))
+
+        def forward(self, x):
+            return self.pool(x)
+
+    raw_data = np.random.randn(2, 3, 8, 8).astype(np.float32)
+    torch_module = AdaptiveAvgPool2dModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_avg_pool2d_module(target, dev):
+    class AvgPool2dModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = nn.AvgPool2d(kernel_size=2)
+
+        def forward(self, x):
+            return self.pool(x)
+
+    raw_data = np.random.randn(2, 3, 8, 8).astype(np.float32)
+    torch_module = AvgPool2dModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_conv1d_module(target, dev):
+    class Conv1dModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Conv1d(in_channels=3, out_channels=4, kernel_size=3)
+
+        def forward(self, x):
+            return self.conv(x)
+
+    raw_data = np.random.randn(2, 3, 10).astype(np.float32)
+    torch_module = Conv1dModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_conv2d_module(target, dev):
+    class Conv2dModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3)
+
+        def forward(self, x):
+            return self.conv(x)
+
+    raw_data = np.random.randn(2, 3, 10, 10).astype(np.float32)
+    torch_module = Conv2dModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_conv3d_module(target, dev):
+    class Conv3dModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Conv3d(in_channels=2, out_channels=3, kernel_size=3)
+
+        def forward(self, x):
+            return self.conv(x)
+
+    raw_data = np.random.randn(1, 2, 8, 8, 8).astype(np.float32)
+    torch_module = Conv3dModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_group_norm_module(target, dev):
+    class GroupNormModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.gn = nn.GroupNorm(num_groups=1, num_channels=4)
+
+        def forward(self, x):
+            return self.gn(x)
+
+    raw_data = np.random.randn(2, 4, 8, 8).astype(np.float32)
+    torch_module = GroupNormModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_layer_norm_module(target, dev):
+    class LayerNormModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.ln = nn.LayerNorm(normalized_shape=8)
+
+        def forward(self, x):
+            return self.ln(x)
+
+    raw_data = np.random.randn(2, 4, 8).astype(np.float32)
+    torch_module = LayerNormModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_linear_module(target, dev):
+    class LinearModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = nn.Linear(10, 5)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    raw_data = np.random.randn(4, 10).astype(np.float32)
+    torch_module = LinearModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_max_pool2d_module(target, dev):
+    class MaxPool2dModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = nn.MaxPool2d(kernel_size=2)
+
+        def forward(self, x):
+            return self.pool(x)
+
+    raw_data = np.random.randn(2, 3, 8, 8).astype(np.float32)
+    torch_module = MaxPool2dModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_embedding_module(target, dev):
+    class EmbeddingModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.embed = nn.Embedding(num_embeddings=10, embedding_dim=3)
+
+        def forward(self, x):
+            return self.embed(x)
+
+    raw_data = np.random.randint(0, 10, (2, 4)).astype(np.int64)
+    torch_module = EmbeddingModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_flatten_module(target, dev):
+    class FlattenModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.flatten = nn.Flatten()
+
+        def forward(self, x):
+            return self.flatten(x)
+
+    raw_data = np.random.randn(2, 3, 4, 5).astype(np.float32)
+    torch_module = FlattenModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_numel(target, dev):
+    class NumelModule(nn.Module):
+        def forward(self, x):
+            return torch.tensor(x.numel())
+
+    raw_data = np.random.randn(2, 3, 4).astype(np.float32)
+    torch_module = NumelModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_size(target, dev):
+    class SizeModule(nn.Module):
+        def forward(self, x):
+            return torch.tensor(x.size(0))
+
+    raw_data = np.random.randn(5, 4).astype(np.float32)
+    torch_module = SizeModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_tensor(target, dev):
+    class TensorModule(nn.Module):
+        def forward(self, x):
+            return torch.tensor([1, 2, 3])
+
+    raw_data = np.zeros((1,)).astype(np.float32)
+    torch_module = TensorModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_type(target, dev):
+    class TypeModule(nn.Module):
+        def forward(self, x):
+            return x.type(torch.float16)
+
+    raw_data = np.random.randn(2, 3).astype(np.float32)
+    torch_module = TypeModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_float(target, dev):
+    class FloatModule(nn.Module):
+        def forward(self, x):
+            return x.float()
+
+    raw_data = np.random.randn(2, 3).astype(np.float32)
+    torch_module = FloatModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_half(target, dev):
+    class HalfModule(nn.Module):
+        def forward(self, x):
+            return x.half()
+
+    raw_data = np.random.randn(2, 3).astype(np.float32)
+    torch_module = HalfModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_getattr(target, dev):
+    class GetAttrModule(nn.Module):
+        def forward(self, x):
+            # Use getattr to call the ndimension method.
+            return torch.tensor(getattr(x, "ndimension")())
+
+    raw_data = np.random.randn(2, 3, 4).astype(np.float32)
+    torch_module = GetAttrModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_sym_size_int(target, dev):
+    class SymSizeIntModule(nn.Module):
+        def forward(self, x):
+            return torch.tensor(x.shape[1])
+
+    raw_data = np.random.randn(2, 3, 4).astype(np.float32)
+    torch_module = SymSizeIntModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_interpolate(target, dev):
+    class InterpolateModule(nn.Module):
+        def forward(self, x):
+            # Upsample to a fixed size.
+            return F.interpolate(x, size=(16, 16), mode="nearest")
+
+    raw_data = np.random.randn(2, 3, 8, 8).astype(np.float32)
+    torch_module = InterpolateModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_cross_entropy_module(target, dev):
+    class CrossEntropyModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.criterion = nn.CrossEntropyLoss()
+            self.target = torch.tensor([0, 1, 2, 1])
+
+        def forward(self, x):
+            return self.criterion(x, self.target)
+
+    raw_data = np.random.randn(4, 3).astype(np.float32)
+    torch_module = CrossEntropyModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index f0bb33964ef2..e5d307b895a9 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -17,6 +17,7 @@
 import operator
 import pytest
 import torch
+from torch import nn
 from torch.nn import Module
 from torch.export import export
 
@@ -4977,6 +4978,36 @@
     verify_model(Eye2(), example_args2, {}, Expected2)
 
 
+def test_cross_entropy():
+    class CrossEntropyModule(Module):
+        def __init__(self):
+            super().__init__()
+            self.criterion = nn.CrossEntropyLoss()
+            self.target = torch.tensor([0, 1, 2, 1])
+
+        def forward(self, x):
+            return self.criterion(x, self.target)
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(x, axis=-1)
+                lv1: R.Tensor((), dtype="float32") = R.nn.nll_loss(
+                    lv,
+                    targets=R.const([0, 1, 2, 1], dtype="int64"),
+                    reduction="mean",
+                    ignore_index=-100,
+                )
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    example_args1 = (torch.randn(4, 3, dtype=torch.float32),)
+    verify_model(CrossEntropyModule(), example_args1, {}, Expected1)
+
+
 def test_linspace():
     class Linspace(Module):
         def forward(self, input):
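
Note (reviewer aid, not part of the diff): every new CUDA test above funnels through `assert_torch_output_vs_tvm_from_exported_to_cuda`, which lives in the renamed test file but outside the hunks shown here. Below is a minimal sketch of the export/build/compare round trip such a helper performs; the function name `compare_torch_and_tvm`, the "zero" relax pipeline, and the specific dlight schedule rules are illustrative assumptions, not the repository's actual implementation.

# Hypothetical sketch only -- the real helper in the test file may differ.
import numpy as np
import torch
from torch.export import export

import tvm
from tvm import dlight as dl
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program


def compare_torch_and_tvm(raw_data, torch_module, target, dev):
    torch_data = torch.from_numpy(raw_data)
    with torch.no_grad():
        exported_program = export(torch_module, (torch_data,))
        mod = from_exported_program(exported_program)

    # Legalize the Relax ops, schedule the resulting TIR for the GPU target
    # (GeneralReduction is the rule patched in this diff), and build.
    mod = relax.get_pipeline("zero")(mod)
    with tvm.target.Target(target):
        mod = dl.ApplyDefaultSchedule(
            dl.gpu.Matmul(),
            dl.gpu.GeneralReduction(),
            dl.gpu.Fallback(),
        )(mod)
    ex = relax.build(mod, target=target)
    vm = relax.VirtualMachine(ex, dev)

    # The converted "main" returns a one-element tuple (see Expected1 above).
    tvm_out = vm["main"](tvm.nd.array(raw_data, dev))[0]
    with torch.no_grad():
        torch_out = torch_module(torch_data)
    np.testing.assert_allclose(
        tvm_out.numpy(), torch_out.detach().cpu().numpy(), rtol=1e-4, atol=1e-4
    )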