Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions colossalai/shardformer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,15 @@ We will follow this roadmap to develop Shardformer:
- [ ] Hugging Face
- [ ] NLP
- [x] BERT
- [ ] T5
- [ ] LlaMa
- [x] T5
- [x] LlaMa
- [ ] GPT2
- [ ] BLOOM
- [ ] RoBERTa
- [ ] ALBERT
- [ ] ERNIE
- [ ] GPT Neo
- [ ] GPT-J
- [ ] CV
- [ ] CV
- [ ] ViT
- [ ] BEiT
Expand Down
26 changes: 17 additions & 9 deletions colossalai/shardformer/layer/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,13 +469,14 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
gather_output: bool = True,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()

self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
self.embedding_dim = embedding_dim
self.process_group = process_group
self.num_partitions = dist.get_world_size(process_group)
self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)
Expand All @@ -499,7 +500,9 @@ def __init__(self,

@staticmethod
def from_native_module(module: nn.Embedding,
process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D":
process_group: Union[ProcessGroup, List[ProcessGroup]] = None,
*args,
**kwargs) -> "Embedding1D":
r"""
Build a 1D parallelized Embedding from a native nn.Embedding module.
"""
Expand Down Expand Up @@ -527,7 +530,9 @@ def from_native_module(module: nn.Embedding,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
sparse=sparse,
*args,
**kwargs)

# copy the weight
with torch.no_grad():
Expand All @@ -537,7 +542,7 @@ def from_native_module(module: nn.Embedding,
return embedding

def reset_parameters(self, weight_initializer) -> None:
fan_in, fan_out = self.num_embeddings, self.embed_dim
fan_in, fan_out = self.num_embeddings, self.embedding_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()

Expand All @@ -548,9 +553,12 @@ def _fill_padding_idx_with_zero(self) -> None:

def forward(self, input_: Tensor) -> Tensor:
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)

return output
if self.gather_output:
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
return output
else:
return output_parallel


class VocabParallelEmbedding1D(ParallelLayer):
Expand Down Expand Up @@ -595,7 +603,7 @@ def __init__(self,
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
Expand All @@ -610,7 +618,7 @@ def __init__(self,
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition

self.weight = Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype))
torch.empty((self.num_embeddings_per_partition, self.embedding_dim), device=device, dtype=dtype))

# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
Expand Down Expand Up @@ -662,7 +670,7 @@ def _set_tensor_parallel_attributes(self):

def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
fan_in, fan_out = self.num_embeddings, self.embedding_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()

Expand Down
6 changes: 6 additions & 0 deletions colossalai/shardformer/policies/autopolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ class PolicyLocation:
PolicyLocation(file_name="llama", class_name="LlamaForSequenceClassificationPolicy"),

# T5
"transformers.models.t5.modeling_t5.T5Model":
PolicyLocation(file_name="t5", class_name="T5ModelPolicy"),
"transformers.models.t5.modeling_t5.T5ForConditionalGeneration":
PolicyLocation(file_name="t5", class_name="T5ForConditionalGenerationPolicy"),
"transformers.models.t5.modeling_t5.T5EncoderModel":
PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"),

# GPT2
}
Expand Down
1 change: 1 addition & 0 deletions colossalai/shardformer/policies/basepolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class SubModuleReplacementDescription:
suffix: str
target_module: ParallelModule
kwargs: Dict[str, Any] = None
ignore_if_not_exist: bool = False


@dataclass
Expand Down
Loading