
ops.reshape errors with !fusion->hasDynamicTransform() #418

@kevinstephano

Description


Error:

Traceback (most recent call last):
  File "/workspace/simple_dl_models/repro.py", line 22, in <module>
    out = fd.execute(inputs)
  File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 76, in execute
    result = self._execute(inputs, override_user_schedule)
RuntimeError: !fusion->hasDynamicTransform() INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/kernel_cache.cpp":619, please report a bug to PyTorch. Fusion must be concretized before constructing FusionKernelRuntime

This pattern appears in the GroupNorm operation, where the input tensor is reshaped. I am not sure what the error indicates, since this pattern will necessarily contain a dynamic transform.
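My reading of the assert is that "concretized" means every symbolic extent in the fusion has been resolved to a concrete value for the given inputs. In this repro the reshape target `[V0[0], V0[1] / S0, S0, V0[2], V0[3]]` depends on the runtime scalar `S0`, so the output shape is only known once inputs arrive. The plain-Python sketch below shows what resolving that target shape for one set of inputs amounts to. `concretize_reshape` is a hypothetical helper for illustration, not an nvFuser API, and it assumes `V0[1] / S0` behaves as exact integer division.

```python
from math import prod


def concretize_reshape(in_shape, split):
    """Hypothetical helper (not nvFuser API): resolve the symbolic reshape
    target [N, C / split, split, H, W] for one concrete input shape and
    scalar value.

    Assumes V0[1] / S0 is integer division and that split divides C exactly.
    """
    n, c, h, w = in_shape
    if split <= 0 or c % split != 0:
        raise ValueError(f"channel dim {c} is not divisible by {split}")
    out_shape = [n, c // split, split, h, w]
    # A reshape must preserve the total number of elements.
    assert prod(out_shape) == prod(in_shape)
    return out_shape


# With the repro's inputs: [256, 128, 28, 28] and S0 = 32
print(concretize_reshape([256, 128, 28, 28], 32))  # [256, 4, 32, 28, 28]
```

Because the result changes with the value of `S0`, one `FusionDefinition` can correspond to different concrete reshapes across calls, which is presumably why the runtime requires a concretized fusion before building a `FusionKernelRuntime`.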

This repro exercises a feature on the add_new_reshape branch, which adds some Python API operations that are not yet in TOT.

Repro:

import torch
from nvfuser import FusionDefinition, DataType

inputs = [
    torch.randn(256, 128, 28, 28, device='cuda'),  # [N, C, H, W]
    32,  # value bound to the Int scalar S0 at execute time
]

def func(fd: FusionDefinition):
    T0 = fd.from_pytorch(inputs[0])
    S0 = fd.define_scalar(dtype=DataType.Int)  # symbolic, not a constant
    V0 = T0.shape()
    T1 = fd.ops.reshape(T0, [V0[0], V0[1] / S0, S0, V0[2], V0[3]])  # target depends on S0
    var, mean = fd.ops.var_mean(T1, axes=[2, 3, 4], correction=0, keepdim=True)
    fd.add_output(var)
    fd.add_output(mean)

with FusionDefinition() as fd:
    func(fd)

for _ in range(5):
    out = fd.execute(inputs)  # raises the assert shown above
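For comparison once the fusion runs, here is a plain-Python sketch of what it should compute. After the reshape to `[N, C / S0, S0, H, W]`, reducing over axes `[2, 3, 4]` with `correction=0` takes the population variance and mean over each run of `S0` consecutive channels (and all of H and W) for a fixed `(n, c // S0)`. Note that with `S0` in position 2, `S0` is the number of channels per group rather than the number of groups. `reference_var_mean` is a hypothetical name; it works on a flat, contiguous NCHW list, drops the size-1 `keepdim` axes, and is only meant for small shapes. It is not how nvFuser evaluates the fusion.

```python
def reference_var_mean(x, shape, split):
    """Hypothetical reference: var/mean of x.reshape(N, C // split, split, H, W)
    over axes (2, 3, 4) with population variance (correction=0).

    x is a flat list in contiguous NCHW order. Returns (var, mean) as nested
    lists indexed [n][g]; the keepdim size-1 axes are dropped.
    """
    n_dim, c_dim, h_dim, w_dim = shape
    groups = c_dim // split
    hw = h_dim * w_dim
    var_out, mean_out = [], []
    for n in range(n_dim):
        var_row, mean_row = [], []
        for g in range(groups):
            # Channels g*split .. g*split + split - 1 are adjacent in NCHW,
            # so each group is one contiguous slice of split * H * W values.
            start = (n * c_dim + g * split) * hw
            vals = x[start:start + split * hw]
            mean = sum(vals) / len(vals)
            var = sum((v - mean) ** 2 for v in vals) / len(vals)
            var_row.append(var)
            mean_row.append(mean)
        var_out.append(var_row)
        mean_out.append(mean_row)
    return var_out, mean_out


# Tiny example: N=1, C=4, H=W=1, split=2 -> groups [1, 3] and [5, 9]
var, mean = reference_var_mean([1.0, 3.0, 5.0, 9.0], [1, 4, 1, 1], 2)
print(var, mean)  # [[1.0, 4.0]] [[2.0, 7.0]]
```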
