From a1b1ec6ee7c02a369246b3b6449f3c85138e992d Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Thu, 12 Jan 2023 14:11:12 +0800 Subject: [PATCH] Use different tags for acts and grads isend/irecv pairs. --- varuna/pipeline.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/varuna/pipeline.py b/varuna/pipeline.py index 254f0f4..159aa08 100644 --- a/varuna/pipeline.py +++ b/varuna/pipeline.py @@ -16,6 +16,9 @@ import os, sys import time +P2P_TAG_ACTIONS = 1 +P2P_TAG_GRADS = 2 + class Pipeline: """ Pipeline parallelism for Varuna """ @@ -131,7 +134,7 @@ def acts_receiver(self): for d in self.fwd_inp_shape_changes: fwd_inp_shape[d] = self.last_chunk_size acts_tensor = torch.ones(fwd_inp_shape, dtype=dtype) - handle = dist.irecv(acts_tensor, src=self.receive_rank) + handle = dist.irecv(acts_tensor, src=self.receive_rank, tag=P2P_TAG_ACTIONS) recv_handles.put((handle, acts_tensor)) if recv_handles.qsize()>4: handle, tensor = recv_handles.get() @@ -156,7 +159,7 @@ def grads_receiver(self): for d in self.bwd_grad_shape_changes: bwd_grad_shape[d] = self.last_chunk_size grads_tensor = torch.ones(bwd_grad_shape, dtype=dtype) - handle = dist.irecv(grads_tensor, src=self.send_rank) + handle = dist.irecv(grads_tensor, src=self.send_rank, tag=P2P_TAG_GRADS) recv_handles.put((handle, grads_tensor)) if recv_handles.qsize()>4: handle, tensor = recv_handles.get() @@ -176,7 +179,7 @@ def acts_sender(self): send_handles = Queue() while count > 0: output_acts = self.acts_send_queue.get() - handle = dist.isend(output_acts, dst=self.send_rank) + handle = dist.isend(output_acts, dst=self.send_rank, tag=P2P_TAG_ACTIONS) send_handles.put(handle) if send_handles.qsize()>4: handle = send_handles.get() @@ -197,7 +200,7 @@ def grads_sender(self): while count > 0: input_grads = self.grads_send_queue.get() # TODO: why is this contiguous needed ??? - handle = dist.isend(input_grads.contiguous(), dst=self.receive_rank) + handle = dist.isend(input_grads.contiguous(), dst=self.receive_rank, tag=P2P_TAG_GRADS) send_handles.put(handle) if send_handles.qsize()>4: handle = send_handles.get()