Skip to content

[Codegen] Fail to capture nested dict for pipeline parallelism #49

@zyeric

Description

@zyeric

Reproduce snippet

#  Copyright (c) Microsoft Corporation.
#  Licensed under the MIT License.

import torch
import torch.nn as nn
import tempfile
import shutil
import contextlib
import pytest
from pathlib import Path


import nnscaler
import nnscaler.graph.function.function as F
from nnscaler.ir.tensor import IRFullTensor
from nnscaler.graph import IRGraph
from nnscaler.ir.adapter import IRAdapter
from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer
from nnscaler.ir.operator import IRFwOperation, IRDataOperation
from nnscaler.graph.segment import IRSegment
from nnscaler.graph.schedule.predefined import PredefinedSched
from tests.utils import clear_dir_on_rank0, init_random, raises_with_cause
from tests.launch_torchrun import torchrun
from tests.parallel_module.test_gencode import _gencode_contains, print_gencode


class Layer(torch.nn.Module):
    """One 4->4 linear transform shifted by a bias pulled from a dict.

    The bias arriving through ``context['bias']`` (rather than as a
    registered parameter) is what exercises non-tensor / nested-dict
    argument handling in the repro.
    """

    def __init__(self):
        super().__init__()
        # Projection carries no bias of its own; the shift comes from `context`.
        self.linear = torch.nn.Linear(4, 4, bias=False)

    def forward(self, x, context):
        projected = self.linear(x)
        return projected + context['bias']


class Model(torch.nn.Module):
    """A stack of ``num_layers`` Layer blocks fed one shared dict-provided bias.

    The layers are held in an ``nn.ModuleList`` so they are registered as
    submodules. The original snippet used a plain Python list, which hides
    the layers' parameters from ``parameters()``, ``state_dict()``, and
    device moves (``.to(...)``) — a registration bug, not part of the
    codegen issue being reproduced.
    """

    def __init__(self, num_layers):
        super().__init__()
        # nn.ModuleList registers each Layer (and its weight) with this module.
        self.layers = nn.ModuleList(Layer() for _ in range(num_layers))

    def forward(self, x, content):
        # Unwrap the nested dict: content == {'content': {'bias': tensor}}.
        context = content['content']
        for layer in self.layers:
            x = layer(x, context)
        # Reduce to a scalar so the graph ends in a loss-like node.
        x = x.sum()
        return x


def pas(graph, cfg):
    """Partition policy: two pipeline stages, one device each.

    Splits the graph at the two linear ops, replicates the dataloader
    across all plan devices, and pins every node of stage i to device i.
    """
    print('graph', graph.nodes())
    dataloader, fc1, gi1, add1, fc2, gi2, add2, loss = graph.nodes()[:8]

    # Anchor a pipeline stage at each linear op, then keep only the
    # forward segments.
    graph.staging([fc1, fc2])
    fw_stages = [
        seg for seg in graph.select(ntype=IRSegment, flatten=False)
        if seg.isfw()
    ]

    # One dataloader replica per device in the plan.
    for device, replica in enumerate(graph.replicate(dataloader, cfg.plan_ngpus)):
        graph.assign(replica, device)

    print('stage 0', fw_stages[0].nodes())
    print('stage 1', fw_stages[1].nodes())
    for device, stage in enumerate(fw_stages[:2]):
        for node in stage.nodes():
            graph.assign(node, device)
    return graph


def test_non_tensor_pp():
    """Repro for issue #49: pipeline-parallel codegen over a nested-dict input.

    Builds a 2-layer model, parallelizes it with the two-stage `pas`
    policy (2 plan GPUs), and dumps the generated code for both ranks
    for inspection. Requires CUDA.
    """
    model = Model(2)
    model.train()
    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(0)
    device = torch.cuda.current_device()
    x = torch.randn([2, 4], dtype=torch.float32, device=device)
    # Nested dict input — the shape that triggers the bad adapter codegen.
    content = {'content': {'bias': torch.randn([4], dtype=torch.float32, device=device)}}

    with tempfile.TemporaryDirectory() as savedir:
        parallelize(
            model,
            {'x': x, 'content': content},
            pas,
            ComputeConfig(2, 2, use_end2end=True),
            reuse='override',
            gen_savedir=savedir,
            load_module=False,
        )
        for rank in (0, 1):
            print_gencode(savedir, Model, rank)

Description

In rank 1's generated code there is a variable `getitem_28` whose adapter is not generated correctly; it is an IRObject with a dynamic size.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions