6 changes: 3 additions & 3 deletions colossalai/shardformer/README.md
@@ -18,19 +18,19 @@
The sample API usage is given below:

``` python
from colossalai.shardformer.shard.shardmodel import ShardModel
from colossalai.shardformer import shard_model
from transformers import BertForMaskedLM

# create huggingface model as normal
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# parallelize the huggingface model
# auto policy:
shardmodel = ShardModel(model).model
sharded_model = shard_model(model)

# custom policy:
from xxx import <POLICYCLASS>
shardmodel = ShardModel(model, <POLICYCLASS>).model
sharded_model = shard_model(model, <POLICYCLASS>)

# do anything as normal
...
5 changes: 5 additions & 0 deletions colossalai/shardformer/shard/__init__.py
@@ -0,0 +1,5 @@
from .shard_config import ShardConfig
from .sharder import ModelSharder, shard_model
from .slicer import Slicer

__all__ = ['ShardConfig', 'ModelSharder', 'shard_model', 'Slicer']
colossalai/shardformer/shard/shard_config.py
@@ -1,5 +1,7 @@
from dataclasses import dataclass

__all__ = ['ShardConfig']


@dataclass
class ShardConfig:
27 changes: 18 additions & 9 deletions colossalai/shardformer/shard/sharder.py
@@ -1,20 +1,15 @@
import os
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, List

import torch
import torch.nn as nn

import colossalai.nn as col_nn
from colossalai.logging import get_dist_logger

from ..policies.autopolicy import get_autopolicy
from ..policies.basepolicy import Layer, Policy
from ..policies.basepolicy import Policy
from ..utils.utils import getattr_, hasattr_, setattr_
from .shardconfig import ShardConfig
from .shard_config import ShardConfig
from .slicer import Slicer

logger = get_dist_logger()
__all__ = ['ModelSharder', 'shard_model']


class ModelSharder(object):
@@ -245,3 +240,17 @@ def bind_layer(self, model: nn.Module) -> None:
param = nn.Parameter(param)
setattr_(model, k, param)
setattr_(model, v, param)


def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Policy = None):
r"""
The function is used to shard the PyTorch model.

Args:
model (`torch.nn.Model`): the origin huggingface model
shard_config (`ShardConfig`): the config for distribute information
policy (`Policy`): the custom policy for sharding
"""
sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy)
sharder.shard()
return model
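For context, a minimal usage sketch of the new `shard_model` entry point (not part of this PR). The `rank` and `world_size` fields passed to `ShardConfig` are assumptions taken from the test script later in this diff; adapt them to the actual `ShardConfig` definition.

``` python
# Hypothetical usage sketch, assuming ShardConfig accepts rank/world_size
# as suggested by colossalai/shardformer/test/test.py below.
import os

import torch.distributed as dist
from transformers import BertForMaskedLM

from colossalai.shardformer.shard import ShardConfig, shard_model

model = BertForMaskedLM.from_pretrained("bert-base-uncased")

shard_config = ShardConfig(
    rank=dist.get_rank(),
    world_size=int(os.environ["WORLD_SIZE"]),
)

# shard_model shards the model in place and returns it
sharded_model = shard_model(model, shard_config)
```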
60 changes: 0 additions & 60 deletions colossalai/shardformer/shard/shardmodel.py

This file was deleted.

7 changes: 1 addition & 6 deletions colossalai/shardformer/shard/slicer.py
@@ -1,12 +1,7 @@
import os
from dataclasses import dataclass
from typing import Dict, Tuple

import torch
import torch.distributed as dist

from ..policies.basepolicy import Col_Layer, Layer, Row_Layer
from .shardconfig import ShardConfig
from .shard_config import ShardConfig

dim_mapping = {Col_Layer: 1, Row_Layer: 0}

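As an aside, a small sketch of what the `dim_mapping` above implies for weight slicing (illustrative only, not code from this PR): column-type layers are split along dim 1 and row-type layers along dim 0, with each rank keeping its own chunk.

``` python
# Illustrative sketch of the Col_Layer/Row_Layer slicing dims; the names and
# shapes here are hypothetical, only the dim choice follows dim_mapping.
import torch

world_size, rank = 2, 0
weight = torch.randn(1024, 1024)

col_shard = torch.chunk(weight, world_size, dim=1)[rank]  # Col_Layer -> dim 1
row_shard = torch.chunk(weight, world_size, dim=0)[rank]  # Row_Layer -> dim 0
```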
15 changes: 6 additions & 9 deletions colossalai/shardformer/test/test.py
@@ -1,18 +1,14 @@
import argparse
import inspect
import os

import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments
from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling

import colossalai
from colossalai.logging import get_dist_logger
from colossalai.shardformer.shard.shardconfig import ShardConfig
from colossalai.shardformer.shard.shardmodel import ShardModel
from colossalai.shardformer.shard import ShardConfig, shard_model
from colossalai.utils import get_current_device, print_rank_0

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
@@ -93,8 +89,9 @@ def train(model: nn.Module, num_epoch: int = 2):
rank=int(str(get_current_device()).split(':')[-1]),
world_size=int(os.environ['WORLD_SIZE']),
)
shardmodel = ShardModel(model, shard_config)
sharded_model = shard_model(model, shard_config)

if args.mode == "train":
train(shardmodel.model)
train(sharded_model)
elif args.mode == "inference":
inference(shardmodel.model)
inference(sharded_model)