[Relax][Torch] Fixed issues related to sum op when without dim and keep dim #18583

Merged

tlopex merged 5 commits into apache:main from locnd182644:relax/torch/sum on Dec 13, 2025

Conversation

@locnd182644 (Contributor) commented Dec 12, 2025

Issue 1: Without Dim

Summary:

In the _sum function (BaseFXGraphImporter), after retrieve_args, args[1] is [] and is still passed into relax.op.sum, so the result is incorrect.
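
In sketch form, the fix normalizes an empty dim list to axis=None before calling relax.op.sum. This is an assumption-level reconstruction of a converter of roughly this shape; the real _sum in BaseFXGraphImporter may differ in detail:

from tvm import relax

def _sum(self, node) -> relax.Var:
    args = self.retrieve_args(node)
    dim = args[1] if len(args) > 1 else node.kwargs.get("dim", None)
    if isinstance(dim, (list, tuple)) and len(dim) == 0:
        # torch.sum(x) with no dim means "reduce over all axes";
        # axis=None makes relax.op.sum produce a scalar, matching PyTorch.
        dim = None
    return self.block_builder.emit(relax.op.sum(args[0], axis=dim))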

Steps to Reproduce

  • Module:

import torch
from torch import nn

class SumWithoutDim(nn.Module):
    def forward(self, x):
        return torch.sum(x)

  • Actual Relax IR (incorrect):

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
        with R.dataflow():
            lv: R.Tensor((2, 3), dtype="float32") = R.sum(x, axis=[], keepdims=False)
            gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
            R.output(gv)
        return gv
  • Result:

Input: tensor([[1., 1., 1.], [1., 1., 1.]])
Torch output: tensor(6.)
Torch output shape: torch.Size([])
TVM output: [[1. 1. 1.] [1. 1. 1.]]
TVM output shape: (2, 3)

Expected

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
        with R.dataflow():
            lv: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
            gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
            R.output(gv)
        return gv
  • Result (with the fix): TVM output: 6.0; TVM output shape: ()
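
The PyTorch side of this expectation is easy to confirm in isolation (plain torch, nothing TVM-specific):

import torch

x = torch.ones(2, 3)
y = torch.sum(x)                  # no dim: full reduction
assert y.shape == torch.Size([])  # 0-d (scalar) tensor
assert y.item() == 6.0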

Issue 2: Keep Dim

Summary:

In the _sum function (BaseFXGraphImporter), the keepdim value was previously read only from node.kwargs and was never passed into relax.op.sum. Now keepdim is also read from args[2] and is forwarded.
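
With this change folded into the Issue 1 sketch above, the whole converter might look like the following (again a reconstruction under the same assumptions, not the verbatim TVM source):

from tvm import relax

def _sum(self, node) -> relax.Var:
    args = self.retrieve_args(node)
    # keepdim now comes from args[2] when present, falling back to kwargs.
    keepdim = args[2] if len(args) > 2 else node.kwargs.get("keepdim", False)
    dim = args[1] if len(args) > 1 else node.kwargs.get("dim", None)
    if isinstance(dim, (list, tuple)) and len(dim) == 0:
        dim = None  # Issue 1: empty dim list -> reduce over all axes
    return self.block_builder.emit(relax.op.sum(args[0], axis=dim, keepdims=keepdim))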

Steps to Reproduce

  • Module:

class SumKeepDim(nn.Module):
    def forward(self, x):
        return torch.sum(x, dim=1, keepdim=True)

  • Actual Relax IR (incorrect):

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2,), dtype="float32")):
        with R.dataflow():
            lv: R.Tensor((2,), dtype="float32") = R.sum(x, axis=[1], keepdims=False)
            gv: R.Tuple(R.Tensor((2,), dtype="float32")) = (lv,)
            R.output(gv)
        return gv

  • Result:

Input: tensor([[1., 1., 1.], [1., 1., 1.]])
Torch output: tensor([[3.], [3.]])
Torch output shape: torch.Size([2, 1])
TVM VM output: [3. 3.]
TVM VM output shape: (2,)

Expected

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2, 1), dtype="float32")):
        with R.dataflow():
            lv: R.Tensor((2, 1), dtype="float32") = R.sum(x, axis=[1], keepdims=True)
            gv: R.Tuple(R.Tensor((2, 1), dtype="float32")) = (lv,)
            R.output(gv)
        return gv

  • Result (with the fix): TVM output: [[3.] [3.]]; TVM output shape: (2, 1)
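
As with Issue 1, the reference PyTorch behavior can be checked directly:

import torch

x = torch.ones(2, 3)
y = torch.sum(x, dim=1, keepdim=True)
assert y.shape == torch.Size([2, 1])  # reduced axis kept with size 1
assert torch.equal(y, torch.full((2, 1), 3.0))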

- WithoutDim: args[1] = [] was still passed into relax.op.sum -> incorrect result
- KeepDim: the keepdim value was previously read only from node.kwargs and not passed into relax.op.sum; now keepdim is also read from args[2] and forwarded.
@locnd182644 (Contributor Author)

@tvm-bot rerun

@mshr-h (Contributor) commented Dec 12, 2025

CI is failing because the PR body contains @R, which is treated as username tagging.
Please remove those from the PR body and rerun CI. @locnd182644

usernames: FAILED: PR body must not tag anyone but found these usernames: ['@I', '@R', '@I', '@R', '@I', '@R', '@I', '@R']

@locnd182644 (Contributor Author)

@mshr-h Thank you. I will modify and rerun it now.

@locnd182644 (Contributor Author)

@tvm-bot rerun

@tlopex (Member) left a comment:

LGTM! Thanks

@tlopex merged commit 6248b5d into apache:main on Dec 13, 2025
10 checks passed
@locnd182644 deleted the relax/torch/sum branch on December 31, 2025 06:32