Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,17 @@ def _cat(self, node: fx.Node) -> relax.Var:
axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
return self.block_builder.emit(relax.op.concat(args[0], axis=axis))

def _chunk(self, node: fx.Node) -> relax.Var:
    """Convert a ``torch.chunk`` / ``Tensor.chunk`` FX node to ``relax.op.split``.

    ``torch.chunk(x, chunks, dim)`` produces pieces of
    ``ceil(dim_size / chunks)`` elements each, so the number of pieces
    actually returned can be *smaller* than ``chunks`` whenever ``chunks``
    does not divide ``dim_size`` (e.g. chunk(size=6, chunks=4) yields
    sizes [2, 2, 2] — only 3 pieces).  Emitting the requested count
    verbatim could therefore produce empty trailing sections, so the
    effective section count is recomputed from the chunk size.  This also
    covers the case ``chunks > dim_size`` (chunk size 1, ``dim_size``
    sections), which the previous ``min`` clamp handled.
    """
    x = self.env[node.args[0]]
    chunks = node.args[1]
    # ``dim`` may be passed positionally or as a keyword; torch defaults it to 0.
    dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
    x_shape = self.shape_of(x)
    # NOTE(review): assumes the split axis has a static extent — ``.value``
    # requires an IntImm, so dynamic shapes are not handled here. Negative
    # ``dim`` values index from the end via normal Python indexing.
    dim_size = x_shape[dim].value
    # ceil(dim_size / chunks) without floating point: torch's chunk size.
    chunk_size = -(-dim_size // chunks)
    # Number of sections torch actually returns for that chunk size.
    n_sections = -(-dim_size // chunk_size)
    return self.block_builder.emit(
        relax.op.split(x=x, indices_or_sections=n_sections, axis=dim)
    )

def _cumsum(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ def create_convert_map(
"argmin.default": self._argmax_argmin(relax.op.argmin),
# tensor manipulation
"cat.default": self._cat,
"chunk.default": self._chunk,
"clamp.Tensor": self._clamp,
"concat.default": self._cat,
"copy_.default": self._copy_,
Expand Down
95 changes: 83 additions & 12 deletions tests/python/relax/test_from_exported_to_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.


import tvm
from tvm import relax
import tvm.testing
Expand Down Expand Up @@ -332,6 +331,30 @@ def forward(self, x):
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


@tvm.testing.parametrize_targets("cuda")
def test_split_sections_list(target, dev):
    # torch.split with an explicit list of section sizes along the height axis
    shape = (3, 2, 10, 5)  # (batch, channels, height, width)
    section_sizes = [3, 2, 5]
    split_dim = 2  # split across height
    raw_data = np.random.rand(*shape).astype("float32")

    class SplitModelSectionsList(nn.Module):
        def __init__(self, split_size, dim):
            super().__init__()
            self.split_size = split_size
            self.dim = dim

        def forward(self, x):
            return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim)

    torch_module = SplitModelSectionsList(split_size=section_sizes, dim=split_dim).eval()
    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


@tvm.testing.parametrize_targets("cuda")
def test_batch_norm0(target, dev):
# Eval, no momentum, no affine, no running stats
Expand Down Expand Up @@ -373,26 +396,74 @@ def test_batch_norm3(target, dev):


@tvm.testing.parametrize_targets("cuda")
def test_chunk_even(target, dev):
    # Chunks is a divisor of the dimension size, so every chunk has equal size.
    batch = 6
    channels = 2
    height = 3
    width = 4
    chunks = 3
    dim = 0  # chunk across the batch axis
    raw_data = np.random.rand(batch, channels, height, width).astype("float32")

    class ChunkModel(nn.Module):
        def __init__(self, chunks, dim):
            super().__init__()
            self.chunks = chunks
            self.dim = dim

        def forward(self, x):
            return x.chunk(self.chunks, dim=self.dim)

    torch_module = ChunkModel(chunks=chunks, dim=dim).eval()
    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


@tvm.testing.parametrize_targets("cuda")
def test_chunk_uneven(target, dev):
    # Chunks is not a divisor of the dimension size: torch returns chunks of
    # size ceil(5 / 2) = 3, with a smaller final chunk ([3, 2]).
    batch = 2
    channels = 5
    height = 4
    width = 5
    chunks = 2
    dim = 1  # chunk across the channel axis
    raw_data = np.random.rand(batch, channels, height, width).astype("float32")

    class ChunkModel(nn.Module):
        def __init__(self, chunks, dim):
            super().__init__()
            self.chunks = chunks
            self.dim = dim

        def forward(self, x):
            return x.chunk(self.chunks, dim=self.dim)

    torch_module = ChunkModel(chunks=chunks, dim=dim).eval()
    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


@tvm.testing.parametrize_targets("cuda")
def test_chunk_too_many(target, dev):
    # If the user asks for more chunks than the size of the dim, PyTorch
    # simply splits into sections of size 1.
    shape = (1, 3, 2, 2)  # (batch, channels, height, width)
    requested_chunks = 99
    chunk_dim = 1
    raw_data = np.random.rand(*shape).astype("float32")

    class ChunkModel(nn.Module):
        def __init__(self, chunks, dim):
            super().__init__()
            self.chunks = chunks
            self.dim = dim

        def forward(self, x):
            return x.chunk(self.chunks, dim=self.dim)

    torch_module = ChunkModel(chunks=requested_chunks, dim=chunk_dim).eval()
    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


Expand Down