Issue (open): adapter generated incorrectly for a dynamically-sized IRObject under pipeline parallelism

Description

Reproduce snippet:
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import tempfile
import shutil
import contextlib
import pytest
from pathlib import Path
import nnscaler
import nnscaler.graph.function.function as F
from nnscaler.ir.tensor import IRFullTensor
from nnscaler.graph import IRGraph
from nnscaler.ir.adapter import IRAdapter
from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer
from nnscaler.ir.operator import IRFwOperation, IRDataOperation
from nnscaler.graph.segment import IRSegment
from nnscaler.graph.schedule.predefined import PredefinedSched
from tests.utils import clear_dir_on_rank0, init_random, raises_with_cause
from tests.launch_torchrun import torchrun
from tests.parallel_module.test_gencode import _gencode_contains, print_gencode
class Layer(torch.nn.Module):
    """A 4->4 linear projection (no bias weight) whose forward adds an
    externally supplied bias taken from a context dict under key 'bias'."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4, bias=False)

    def forward(self, x, context):
        # Project first, then shift by the caller-provided bias tensor.
        projected = self.linear(x)
        return projected + context['bias']
class Model(torch.nn.Module):
    """A stack of `num_layers` Layer modules applied sequentially, reduced
    to a scalar via sum() so the output can be used directly as a loss.

    `content` is expected to be a nested dict {'content': {'bias': Tensor}};
    the inner dict is forwarded to every layer.
    """

    def __init__(self, num_layers):
        super().__init__()
        # BUG FIX: the original stored layers in a plain Python list, which
        # hides the sub-modules from nn.Module bookkeeping — parameters(),
        # .to(device), state_dict(), etc. would all silently skip them.
        # nn.ModuleList registers them while keeping iteration identical.
        self.layers = nn.ModuleList(Layer() for _ in range(num_layers))

    def forward(self, x, content):
        context = content['content']
        for layer in self.layers:
            x = layer(x, context)
        # Reduce to a scalar loss.
        x = x.sum()
        return x
def pas(graph, cfg):
    """Placement policy: cut the graph into two pipeline stages (one per
    linear op) and place forward stage i on device i; the dataloader is
    replicated across all `cfg.plan_ngpus` devices.

    Assumes the first 8 graph nodes are (dataloader, fc1, getitem1, add1,
    fc2, getitem2, add2, loss), as produced by the 2-layer Model above.
    """
    print('graph', graph.nodes())
    dataloader, fc1, _gi1, _add1, fc2, _gi2, _add2, _loss = graph.nodes()[:8]
    # Each linear op starts a new stage.
    graph.staging([fc1, fc2])
    segments = graph.select(ntype=IRSegment, flatten=False)
    fw_stages = [seg for seg in segments if seg.isfw()]
    # Replicate the dataloader onto every device.
    for device, replica in enumerate(graph.replicate(dataloader, cfg.plan_ngpus)):
        graph.assign(replica, device)
    print('stage 0', fw_stages[0].nodes())
    print('stage 1', fw_stages[1].nodes())
    # One forward stage per device.
    for node in fw_stages[0].nodes():
        graph.assign(node, 0)
    for node in fw_stages[1].nodes():
        graph.assign(node, 1)
    return graph
def test_non_tensor_pp():
    """Repro for the reported issue: pipeline-parallel codegen with a
    non-tensor (nested dict) forward argument produces an incorrect
    adapter for a dynamically-sized IRObject on rank 1."""
    m = Model(2)
    m.train()
    # Seed CPU and CUDA RNGs for reproducible weights and inputs.
    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(0)
    # NOTE(review): the tensors below unconditionally use the current CUDA
    # device, so this test requires a GPU despite the is_available() guard
    # above — confirm whether a CPU fallback is intended.
    x = torch.randn([2, 4], dtype=torch.float32, device=torch.cuda.current_device())
    content = {'content': {'bias': torch.randn([4], dtype=torch.float32, device=torch.cuda.current_device())}}
    with tempfile.TemporaryDirectory() as tempdir:
        parallelize(
            m,
            {'x': x, 'content': content},
            pas,
            ComputeConfig(2, 2, use_end2end=True),
            reuse='override',
            gen_savedir=tempdir,
            load_module=False,
        )
        print_gencode(tempdir, Model, 0)
        # BUG FIX: the pasted snippet had the page-extraction artifact
        # "...Model, 1)Description" fused onto this line — a syntax error.
        print_gencode(tempdir, Model, 1)
In rank 1's generated code there is a variable `getitem_28` whose adapter is not generated correctly; it is an IRObject with a dynamic size.
Reactions are currently unavailable
Metadata
Assignees: none
Labels: none