Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions fairscale/nn/pipe/async_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ def __init__(
warnings.warn("More ranks than partitions, some ranks unused")
self.partitions: List[ModuleWrapper] = []
self.pipeline = None
# TODO(msb) remove this hack
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you give some context about what it should be instead?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AsyncPipe seems to manage partitions as a list of ModuleWrapper while MultiProcessPipe only manages a single nn.Sequential. I don't totally understand what the List[ModuleWrapper] logic is all about and there don't seem to be any tests for that really. All the tests just inspect the first partition in the list. So this hack allows us to simplify MultiProcessPipe while allowing the tests to continue functioning.

I'm currently focused on rationalizing MultiProcessPipe. When I move on to AsyncPipe, I plan to fully understand what is actually going on here. I don't totally know how to undo this hack just yet. My hope is that I can remove List[ModuleWrapper] and just have a single nn.Sequential, but I currently don't know if that is possible.

self.partition = None
else:
self.partitions = self.instantiate_partition(module, self.balance, self.group)
if deferred_batch_norm:
Expand All @@ -200,6 +202,8 @@ def __init__(
for name, part in enumerate(self.partitions):
self.add_module(str(name), part.module)
self.create_pipeline()
# TODO(msb) remove this hack
self.partition = self.partitions[0].module

del module

Expand Down
5 changes: 1 addition & 4 deletions fairscale/nn/pipe/async_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from .messages import Transport
from .microbatch import Batch
from .multiprocess_pipeline import create_task
from .skip.tracker import SkipTrackerThroughPotals
from .types import EVENT_LOOP_QUEUE, PipeMessage, Tensors

Expand Down Expand Up @@ -191,10 +192,6 @@ def run_invocation(
"""Actually run the forward pass for a given module, and send the result
to the next stage in the pipeline if needed."""

# We import here to avoid a cyclic dependency.
# TODO(msb) Break the cyclic dependency.
from .multiprocess_pipeline import create_task

task = create_task(
self.checkpoint_stop, batch.index, self.group.rank(), batch, partition.module, skip_trackers,
)
Expand Down
6 changes: 1 addition & 5 deletions fairscale/nn/pipe/multiprocess_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from fairscale.nn.model_parallel import get_model_parallel_world_size, get_pipeline_parallel_group

from . import microbatch
from .async_schedule import Location, ModuleWrapper
from .batchnorm import DeferredBatchNorm
from .multiprocess_pipeline import MultiProcessPipeline
from .phony import get_phony
Expand Down Expand Up @@ -219,17 +218,14 @@ def __init__(
self.add_module(str(0), self.partition)
self.create_pipeline()

# TODO(msb) Remove this hack at some point.
self.partitions = [ModuleWrapper(self.partition, Location(self.group.rank(), 0))]

del module

def create_pipeline(self) -> None:
# The micro-batch index where the checkpointing stops.
checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]

self.pipeline = MultiProcessPipeline(
[ModuleWrapper(self.partition, Location(self.group.rank(), 0))],
self.partition,
self._skip_layout,
checkpoint_stop,
group=self.group,
Expand Down
12 changes: 4 additions & 8 deletions fairscale/nn/pipe/multiprocess_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

from fairscale.nn.model_parallel import get_pipeline_parallel_ranks

from .async_schedule import ModuleWrapper
from .checkpoint import Checkpointing
from .messages import MakeTransport, Transport
from .microbatch import Batch
Expand Down Expand Up @@ -162,7 +161,7 @@ class MultiProcessPipeline:

def __init__(
self,
partitions: List[ModuleWrapper],
partition: nn.Sequential,
skip_layout: SkipLayout,
checkpoint_stop: int,
group: torch.distributed.ProcessGroup,
Expand All @@ -171,7 +170,7 @@ def __init__(
input_device: Union[None, int, str, torch.device] = None,
final_stage: bool = False,
) -> None:
self.partitions = partitions
self.partition = partition
self.skip_layout = skip_layout
self.__checkpoint_stop = checkpoint_stop
self.group = group
Expand All @@ -187,7 +186,7 @@ def __init__(
@property
def checkpoint_stop(self) -> int:
# Disable checkpointing if in eval mode.
training = self.partitions[0].module.training
training = self.partition.training
if not training:
return 0
return self.__checkpoint_stop
Expand All @@ -208,15 +207,12 @@ def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> N
schedule = [(i, self.group.rank()) for i in range(m)]

for i, j in schedule:
assert len(self.partitions) == 1
partition = self.partitions[0]

if self.group.rank() != 0:
batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
else:
batch = batches[i]

task = create_task(self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)
task = create_task(self.checkpoint_stop, i, j, batch, self.partition, skip_trackers)

batches[i] = self.execute_task(task, i, skip_trackers)

Expand Down
8 changes: 3 additions & 5 deletions tests/nn/pipe_process/test_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,8 @@ def hook(module, input, output):
nonlocal latent
latent = output

partition = model.partitions[0]
partition.module.register_forward_hook(hook)
partition = model.partition
partition.register_forward_hook(hook)

with torch.no_grad():
model(input)
Expand Down Expand Up @@ -616,9 +616,7 @@ def partitions(pipe_class):
model = nn.Sequential(a, b)
model = pipe_class(model, [1, 1], worker_map=get_worker_map())

assert isinstance(model.partitions, list)
assert len(model) == 1
assert isinstance(model.partitions[0].module, nn.Sequential)
assert isinstance(model.partition, nn.Sequential)

if model.group.rank() == 0:
assert model[0].weight == a.weight
Expand Down
4 changes: 2 additions & 2 deletions tests/nn/pipe_process/test_transparency.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ def zero_grad(parameters):
if model.group.rank() == 1:
loss = outputs.mean()
loss.backward()
grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())
grad_with_pipe = sum_grad(model.partition.parameters())

# Both grads should be identical.
assert torch.allclose(grad_with_pipe, grad_without_pipe[1])
else:
model.back_helper(outputs)
grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())
grad_with_pipe = sum_grad(model.partition.parameters())

# Both grads should be identical.
assert torch.allclose(grad_with_pipe, grad_without_pipe[0])
Expand Down