-
Notifications
You must be signed in to change notification settings - Fork 22
Open
Description
Environment:
- Python 3.10.14
- torch 2.6.0+cu124
- nnscaler 0.7

Steps to reproduce:
import tempfile
import torch
from nnscaler.graph.parser.register import register_op
from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer
from tests.parallel_module.test_gencode import _gencode_contains, print_gencode
# Note: this annotation is not correct, just for illustration purposes.
# NOTE(review): the annotation 'a b -> a b' has a single input pattern, so it
# describes only `x`; the tensor passed through **kwargs (`y`) is not covered
# by it. Whether register_op tracks tensor kwargs is exactly what this issue
# is about.
@register_op('a b -> a b')
def foo(x: torch.Tensor, **kwargs) -> torch.Tensor:
    """Return ``x + kwargs['y']``.

    The function is registered through ``register_op`` (imported from
    ``nnscaler.graph.parser.register``), presumably so nnscaler's graph parser
    treats it as one custom node instead of tracing into its body
    (TODO confirm against the register_op documentation).

    Args:
        x: Input tensor; the only argument described by the annotation.
        **kwargs: Must contain ``'y'``, which is added to ``x``. In this
            reproduction ``y`` is a tensor produced inside the traced graph.

    Returns:
        ``x + kwargs['y']``.

    Raises:
        KeyError: If ``'y'`` is not supplied.
    """
    return x + kwargs['y']
class TestModule(torch.nn.Module):
    """Minimal module that passes a graph-produced tensor to ``foo`` by keyword.

    ``forward`` computes ``x = linear(t)`` and ``y = x + 1``, then calls
    ``foo(x, y=y)``. ``y`` is consumed only as a keyword argument, which is
    the pattern the issue reports as mishandled by code generation.
    """

    def __init__(self):
        super().__init__()
        # Single Linear layer with in_features = out_features = 10.
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, t):
        x = self.linear(t)
        y = x + 1
        # `y` is intentionally passed by keyword rather than positionally:
        # per the issue, such kwarg tensors are not tracked, so the generated
        # code deletes `y` before `foo` uses it. Keep this statement order
        # as-is to preserve the reproduction.
        z = foo(x, y=y)
        return z
def test_kwargs():
    """Reproduce the kwargs-tensor tracking bug and print the generated code.

    Traces ``TestModule`` with a ``(2, 10)`` float32 input and runs
    ``parallelize`` with the ``'dp'`` policy. The generated code is written
    into a temporary directory and not loaded. The code is then printed so
    that the misplaced ``del`` can be inspected.
    """
    m = TestModule()
    trace_data = torch.randn([2, 10], dtype=torch.float32)
    with tempfile.TemporaryDirectory() as tempdir:
        pas_cfg = {
            'parallel_profile': False
        }
        parallelize(
            m,
            {'t': trace_data},
            'dp',
            # NOTE(review): the two positional 1s presumably mean one planning
            # GPU and one runtime GPU (a single-device run) — confirm against
            # ComputeConfig's signature.
            ComputeConfig(1, 1, use_end2end=False, pas_config=pas_cfg),
            reuse='override',
            gen_savedir=tempdir,
            load_module=False,
        )
        # Must stay inside the `with` block: the generated files are written
        # to `tempdir` (gen_savedir), which is deleted when the block exits.
        # The last argument is presumably the rank whose code is printed
        # (TODO confirm).
        print_gencode(tempdir, TestModule, 0)


# Explanation: if a tensor is passed as a keyword argument, the system does
# not track it correctly. As a result, the generated code looks like the
# excerpt below, and the runtime then throws an error because `add_20` is
# deleted before it is used as the `y=` argument of `foo`:
#
#     add_20 = torch.add(linear_34, 1, alpha=1)
#     del linear_34, add_20
#     foo_16 = tests.codegen.issue_codegen_dep_42.foo(linear_38, y=add_20)
Expected behavior
- Support passing tensors as keyword arguments (kwargs).
- Handle the kwargs case when parsing a registered node.
- Generate code correctly.
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels