Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions examples/images/vit/configs/vit_1d_tp2_ci.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from colossalai.amp import AMP_TYPE

# hyperparameters
# BATCH_SIZE is as per GPU
# global batch size = BATCH_SIZE x data parallel size
BATCH_SIZE = 8
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 3
WARMUP_EPOCHS = 1

# model config
IMG_SIZE = 224
PATCH_SIZE = 16
HIDDEN_SIZE = 32
DEPTH = 2
NUM_HEADS = 4
MLP_RATIO = 4
NUM_CLASSES = 10
CHECKPOINT = False
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token

USE_DDP = True
TP_WORLD_SIZE = 2
TP_TYPE = 'row'
parallel = dict(tensor=dict(mode="1d", size=TP_WORLD_SIZE),)

fp16 = dict(mode=AMP_TYPE.NAIVE)
clip_grad_norm = 1.0
gradient_accumulation = 2

LOG_PATH = "./log_ci"
6 changes: 6 additions & 0 deletions examples/images/vit/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
colossalai >= 0.1.12
torch >= 1.8.1
numpy>=1.24.1
timm>=0.6.12
titans>=0.0.7
tqdm>=4.61.2
transformers>=4.25.1
nvidia-dali-cuda110>=1.8.0 --extra-index-url https://developer.download.nvidia.com/compute/redist
9 changes: 9 additions & 0 deletions examples/images/vit/test_ci.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
export OMP_NUM_THREADS=4

pip install -r requirements.txt

# train
colossalai run \
--nproc_per_node 4 train.py \
--config configs/vit_1d_tp2_ci.py \
--dummy_data
25 changes: 19 additions & 6 deletions examples/images/vit/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from timm.models.vision_transformer import _create_vision_transformer
from titans.dataloader.imagenet import build_dali_imagenet
from tqdm import tqdm
from vit import DummyDataLoader

import colossalai
from colossalai.core import global_context as gpc
Expand Down Expand Up @@ -56,8 +57,8 @@ def init_spec_func(model, tp_type):
def train_imagenet():

parser = colossalai.get_default_parser()
parser.add_argument('--from_torch', default=True, action='store_true')
parser.add_argument('--resume_from', default=False)
parser.add_argument('--resume_from', default=False, action='store_true')
parser.add_argument('--dummy_data', default=False, action='store_true')

args = parser.parse_args()
colossalai.launch_from_torch(config=args.config)
Expand All @@ -74,10 +75,22 @@ def train_imagenet():
logger.log_to_file(log_path)

logger.info('Build data loader', ranks=[0])
root = os.environ['DATA']
train_dataloader, test_dataloader = build_dali_imagenet(root,
train_batch_size=gpc.config.BATCH_SIZE,
test_batch_size=gpc.config.BATCH_SIZE)
if not args.dummy_data:
root = os.environ['DATA']
train_dataloader, test_dataloader = build_dali_imagenet(root,
train_batch_size=gpc.config.BATCH_SIZE,
test_batch_size=gpc.config.BATCH_SIZE)
else:
train_dataloader = DummyDataLoader(length=10,
batch_size=gpc.config.BATCH_SIZE,
category=gpc.config.NUM_CLASSES,
image_size=gpc.config.IMG_SIZE,
return_dict=False)
test_dataloader = DummyDataLoader(length=5,
batch_size=gpc.config.BATCH_SIZE,
category=gpc.config.NUM_CLASSES,
image_size=gpc.config.IMG_SIZE,
return_dict=False)

logger.info('Build model', ranks=[0])

Expand Down
23 changes: 13 additions & 10 deletions examples/images/vit/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,24 @@ def __len__(self):


class DummyDataLoader(DummyDataGenerator):
batch_size = 4
channel = 3
category = 8
image_size = 224

def __init__(self, length=10, batch_size=4, channel=3, category=8, image_size=224, return_dict=True):
super().__init__(length)
self.batch_size = batch_size
self.channel = channel
self.category = category
self.image_size = image_size
self.return_dict = return_dict

def generate(self):
image_dict = {}
image_dict['pixel_values'] = torch.rand(DummyDataLoader.batch_size,
DummyDataLoader.channel,
DummyDataLoader.image_size,
DummyDataLoader.image_size,
device=get_current_device()) * 2 - 1
image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,),
image_dict['pixel_values'] = torch.rand(
self.batch_size, self.channel, self.image_size, self.image_size, device=get_current_device()) * 2 - 1
image_dict['label'] = torch.randint(self.category, (self.batch_size,),
dtype=torch.int64,
device=get_current_device())
if not self.return_dict:
return image_dict['pixel_values'], image_dict['label']
return image_dict


Expand Down