From c03d43342227c7075bdd29e0de377e2c9f2de637 Mon Sep 17 00:00:00 2001 From: Mandeep Singh Baines Date: Fri, 5 Feb 2021 16:47:43 -0800 Subject: [PATCH 1/3] [refactor] remove multiprocess dependency on async --- fairscale/nn/pipe/async_pipe.py | 4 ++++ fairscale/nn/pipe/async_schedule.py | 5 +---- fairscale/nn/pipe/multiprocess_pipe.py | 6 +----- fairscale/nn/pipe/multiprocess_pipeline.py | 12 ++++-------- tests/nn/pipe_process/test_pipe.py | 6 ++---- tests/nn/pipe_process/test_transparency.py | 4 ++-- 6 files changed, 14 insertions(+), 23 deletions(-) diff --git a/fairscale/nn/pipe/async_pipe.py b/fairscale/nn/pipe/async_pipe.py index 3da844d21..7cc6ebd52 100644 --- a/fairscale/nn/pipe/async_pipe.py +++ b/fairscale/nn/pipe/async_pipe.py @@ -192,6 +192,8 @@ def __init__( warnings.warn("More ranks than partitions, some ranks unused") self.partitions: List[ModuleWrapper] = [] self.pipeline = None + # TODO(msb) remove this hack + self.partition = None else: self.partitions = self.instantiate_partition(module, self.balance, self.group) if deferred_batch_norm: @@ -200,6 +202,8 @@ def __init__( for name, part in enumerate(self.partitions): self.add_module(str(name), part.module) self.create_pipeline() + # TODO(msb) remove this hack + self.partition = self.partitions[0].module del module diff --git a/fairscale/nn/pipe/async_schedule.py b/fairscale/nn/pipe/async_schedule.py index 1bfd1fef0..166d7bbcc 100644 --- a/fairscale/nn/pipe/async_schedule.py +++ b/fairscale/nn/pipe/async_schedule.py @@ -17,6 +17,7 @@ from .messages import Transport from .microbatch import Batch +from .multiprocess_pipeline import create_task from .skip.tracker import SkipTrackerThroughPotals from .types import EVENT_LOOP_QUEUE, PipeMessage, Tensors @@ -191,10 +192,6 @@ def run_invocation( """Actually run the forward pass for a given module, and send the result to the next stage in the pipeline if needed.""" - # We import here to avoid a cyclic dependency. - # TODO(msb) Break the cyclic dependency. - from .multiprocess_pipeline import create_task - task = create_task( self.checkpoint_stop, batch.index, self.group.rank(), batch, partition.module, skip_trackers, ) diff --git a/fairscale/nn/pipe/multiprocess_pipe.py b/fairscale/nn/pipe/multiprocess_pipe.py index d704daec3..4f1149ce1 100644 --- a/fairscale/nn/pipe/multiprocess_pipe.py +++ b/fairscale/nn/pipe/multiprocess_pipe.py @@ -31,7 +31,6 @@ from fairscale.nn.model_parallel import get_model_parallel_world_size, get_pipeline_parallel_group from . import microbatch -from .async_schedule import Location, ModuleWrapper from .batchnorm import DeferredBatchNorm from .multiprocess_pipeline import MultiProcessPipeline from .phony import get_phony @@ -219,9 +218,6 @@ def __init__( self.add_module(str(0), self.partition) self.create_pipeline() - # TODO(msb) Remove this hack at some point. - self.partitions = [ModuleWrapper(self.partition, Location(self.group.rank(), 0))] - del module def create_pipeline(self) -> None: @@ -229,7 +225,7 @@ def create_pipeline(self) -> None: checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint] self.pipeline = MultiProcessPipeline( - [ModuleWrapper(self.partition, Location(self.group.rank(), 0))], + self.partition, self._skip_layout, checkpoint_stop, group=self.group, diff --git a/fairscale/nn/pipe/multiprocess_pipeline.py b/fairscale/nn/pipe/multiprocess_pipeline.py index 1b9e091fc..173027c11 100644 --- a/fairscale/nn/pipe/multiprocess_pipeline.py +++ b/fairscale/nn/pipe/multiprocess_pipeline.py @@ -30,7 +30,6 @@ from fairscale.nn.model_parallel import get_pipeline_parallel_ranks -from .async_schedule import ModuleWrapper from .checkpoint import Checkpointing from .messages import MakeTransport, Transport from .microbatch import Batch @@ -162,7 +161,7 @@ class MultiProcessPipeline: def __init__( self, - partitions: List[ModuleWrapper], + partition: nn.Sequential, skip_layout: SkipLayout, checkpoint_stop: int, group: torch.distributed.ProcessGroup, @@ -171,7 +170,7 @@ def __init__( input_device: Union[None, int, str, torch.device] = None, final_stage: bool = False, ) -> None: - self.partitions = partitions + self.partition = partition self.skip_layout = skip_layout self.__checkpoint_stop = checkpoint_stop self.group = group @@ -187,7 +186,7 @@ def __init__( @property def checkpoint_stop(self) -> int: # Disable checkpointing if in eval mode. - training = self.partitions[0].module.training + training = self.partition.training if not training: return 0 return self.__checkpoint_stop @@ -208,15 +207,12 @@ def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> N schedule = [(i, self.group.rank()) for i in range(m)] for i, j in schedule: - assert len(self.partitions) == 1 - partition = self.partitions[0] - if self.group.rank() != 0: batch = self.get_batch_from_previous_stage(i, skip_trackers, batches) else: batch = batches[i] - task = create_task(self.checkpoint_stop, i, j, batch, partition.module, skip_trackers) + task = create_task(self.checkpoint_stop, i, j, batch, self.partition, skip_trackers) batches[i] = self.execute_task(task, i, skip_trackers) diff --git a/tests/nn/pipe_process/test_pipe.py b/tests/nn/pipe_process/test_pipe.py index d0ab3f7cd..b6bd38c41 100644 --- a/tests/nn/pipe_process/test_pipe.py +++ b/tests/nn/pipe_process/test_pipe.py @@ -366,7 +366,7 @@ def hook(module, input, output): nonlocal latent latent = output - partition = model.partitions[0] + partition = model.partition partition.module.register_forward_hook(hook) with torch.no_grad(): @@ -616,9 +616,7 @@ def partitions(pipe_class): model = nn.Sequential(a, b) model = pipe_class(model, [1, 1], worker_map=get_worker_map()) - assert isinstance(model.partitions, list) - assert len(model) == 1 - assert isinstance(model.partitions[0].module, nn.Sequential) + assert isinstance(model.partition, nn.Sequential) if model.group.rank() == 0: assert model[0].weight == a.weight diff --git a/tests/nn/pipe_process/test_transparency.py b/tests/nn/pipe_process/test_transparency.py index 262dfc60e..32e361bf9 100644 --- a/tests/nn/pipe_process/test_transparency.py +++ b/tests/nn/pipe_process/test_transparency.py @@ -60,13 +60,13 @@ def zero_grad(parameters): if model.group.rank() == 1: loss = outputs.mean() loss.backward() - grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters()) + grad_with_pipe = sum_grad(model.pipeline.partition.parameters()) # Both grads should be identical. assert torch.allclose(grad_with_pipe, grad_without_pipe[1]) else: model.back_helper(outputs) - grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters()) + grad_with_pipe = sum_grad(model.pipeline.partition.parameters()) # Both grads should be identical. assert torch.allclose(grad_with_pipe, grad_without_pipe[0]) From 335a67435ee5695196f38dc73aac66000a80b626 Mon Sep 17 00:00:00 2001 From: Mandeep Singh Baines Date: Mon, 8 Feb 2021 13:31:35 -0800 Subject: [PATCH 2/3] Fix test --- tests/nn/pipe_process/test_pipe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/nn/pipe_process/test_pipe.py b/tests/nn/pipe_process/test_pipe.py index b6bd38c41..b14a1befb 100644 --- a/tests/nn/pipe_process/test_pipe.py +++ b/tests/nn/pipe_process/test_pipe.py @@ -367,7 +367,7 @@ def hook(module, input, output): latent = output partition = model.partition - partition.module.register_forward_hook(hook) + partition.register_forward_hook(hook) with torch.no_grad(): model(input) From 867f7034c182e7eba6a3474fa46421bebeea93d0 Mon Sep 17 00:00:00 2001 From: Mandeep Singh Baines Date: Mon, 8 Feb 2021 16:26:16 -0800 Subject: [PATCH 3/3] Fix test failure --- tests/nn/pipe_process/test_transparency.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/nn/pipe_process/test_transparency.py b/tests/nn/pipe_process/test_transparency.py index 32e361bf9..9de481584 100644 --- a/tests/nn/pipe_process/test_transparency.py +++ b/tests/nn/pipe_process/test_transparency.py @@ -60,13 +60,13 @@ def zero_grad(parameters): if model.group.rank() == 1: loss = outputs.mean() loss.backward() - grad_with_pipe = sum_grad(model.pipeline.partition.parameters()) + grad_with_pipe = sum_grad(model.partition.parameters()) # Both grads should be identical. assert torch.allclose(grad_with_pipe, grad_without_pipe[1]) else: model.back_helper(outputs) - grad_with_pipe = sum_grad(model.pipeline.partition.parameters()) + grad_with_pipe = sum_grad(model.partition.parameters()) # Both grads should be identical. assert torch.allclose(grad_with_pipe, grad_without_pipe[0])