132 changes: 16 additions & 116 deletions examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
@@ -4,119 +4,37 @@

import torch
from titans.utils import barrier_context
from torch.fx import GraphModule
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet50
from tqdm import tqdm

import colossalai
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.solver.options import DataloaderOption, SolverOptions
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize
from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingLR
from colossalai.utils import get_dataloader

DATA_ROOT = Path(os.environ.get('DATA', '../data')).absolute()


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--synthetic', action="store_true", help="use synthetic dataset instead of CIFAR10")
return parser.parse_args()


def synthesize_data():
img = torch.rand(gpc.config.BATCH_SIZE, 3, 32, 32)
label = torch.randint(low=0, high=10, size=(gpc.config.BATCH_SIZE,))
return img, label


def main():
args = parse_args()
colossalai.launch_from_torch(config='./config.py')

logger = get_dist_logger()

if not args.synthetic:
with barrier_context():
# build dataloaders
train_dataset = CIFAR10(root=DATA_ROOT,
download=True,
transform=transforms.Compose([
transforms.RandomCrop(size=32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
std=[0.2023, 0.1994, 0.2010]),
]))

test_dataset = CIFAR10(root=DATA_ROOT,
train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
]))

train_dataloader = get_dataloader(
dataset=train_dataset,
add_sampler=True,
shuffle=True,
batch_size=gpc.config.BATCH_SIZE,
pin_memory=True,
)

test_dataloader = get_dataloader(
dataset=test_dataset,
add_sampler=True,
batch_size=gpc.config.BATCH_SIZE,
pin_memory=True,
)
else:
train_dataloader, test_dataloader = None, None

# initialize device mesh
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

# trace the model with meta data
tracer = ColoTracer()
model = resnet50(num_classes=10).cuda()
input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()

# prepare info for solver
solver_options = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm)

# solve the solution
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])
if gpc.get_global_rank() == 0:
for index, node in enumerate(graph.nodes):
print(node.name, node.strategies_vector[solution[index]].name)

# process the graph for distributed training ability
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
gm = runtime_apply_pass(gm)
gm.recompile()

model = autoparallelize(model, input_sample)
# build criterion
criterion = torch.nn.CrossEntropyLoss()

@@ -127,65 +45,47 @@ def main():
lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)

for epoch in range(gpc.config.NUM_EPOCHS):
gm.train()
model.train()

if args.synthetic:
# if we use synthetic data
# we assume it only has 30 steps per epoch
num_steps = range(30)
# if we use synthetic data
# we assume it only has 30 steps per epoch
num_steps = range(30)

else:
# we use the actual number of steps for training
num_steps = range(len(train_dataloader))
data_iter = iter(train_dataloader)
progress = tqdm(num_steps)

for _ in progress:
if args.synthetic:
# generate fake data
img, label = synthesize_data()
else:
# get the real data
img, label = next(data_iter)
# generate fake data
img, label = synthesize_data()

img = img.cuda()
label = label.cuda()
optimizer.zero_grad()
output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
output = model(img)
train_loss = criterion(output, label)
train_loss.backward()
optimizer.step()
lr_scheduler.step()

# run evaluation
gm.eval()
model.eval()
correct = 0
total = 0

if args.synthetic:
# if we use synthetic data
# we assume it only has 30 steps for evaluation
num_steps = range(30)
# if we use synthetic data
# we assume it only has 30 steps for evaluation
num_steps = range(30)

else:
# we use the actual number of steps for training
num_steps = range(len(test_dataloader))
data_iter = iter(test_dataloader)
progress = tqdm(num_steps)

for _ in progress:
if args.synthetic:
# generate fake data
img, label = synthesize_data()
else:
# get the real data
img, label = next(data_iter)
# generate fake data
img, label = synthesize_data()

img = img.cuda()
label = label.cuda()

with torch.no_grad():
output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
output = model(img)
test_loss = criterion(output, label)
pred = torch.argmax(output, dim=-1)
correct += torch.sum(pred == label)
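For quick reference, the net effect of this diff is to collapse the manual pipeline (ColoTracer trace, StrategiesConstructor, CostGraph, Solver, and the runtime preparation/apply passes) into a single autoparallelize call. Below is a minimal sketch of the simplified flow assembled from the lines above; it assumes the same config.py (BATCH_SIZE, NUM_EPOCHS) and the autoparallelize(model, input_sample) signature shown in the diff, uses a placeholder optimizer, and is not an exact copy of the merged file.

import torch
from torchvision.models import resnet50

import colossalai
from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize
from colossalai.core import global_context as gpc


def train_sketch():
    # launched via `colossalai run --nproc_per_node 4 ...`, as in test_ci.sh
    colossalai.launch_from_torch(config='./config.py')

    # build the model and describe its input with meta tensors;
    # the global batch is the per-process batch size times the world size
    model = resnet50(num_classes=10).cuda()
    world_size = torch.distributed.get_world_size()
    input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * world_size, 3, 32, 32]).to('meta')}

    # one call now covers tracing, strategy construction, solving and the runtime passes
    model = autoparallelize(model, input_sample)

    # afterwards the module is driven like a plain nn.Module on synthetic data
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)    # placeholder optimizer/lr
    model.train()
    for _ in range(30):    # the tutorial assumes 30 synthetic steps per epoch
        img = torch.rand(gpc.config.BATCH_SIZE, 3, 32, 32).cuda()
        label = torch.randint(0, 10, (gpc.config.BATCH_SIZE,)).cuda()
        optimizer.zero_grad()
        loss = criterion(model(img), label)
        loss.backward()
        optimizer.step()


if __name__ == '__main__':
    train_sketch()
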
32 changes: 32 additions & 0 deletions examples/tutorial/auto_parallel/environment.yaml
@@ -0,0 +1,32 @@
name: auto
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_kmp_llvm
- blas=1.0=mkl
- brotlipy=0.7.0=py38h27cfd23_1003
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2022.12.7=ha878542_0
- certifi=2022.12.7=pyhd8ed1ab_0
- cffi=1.15.1=py38h74dc2b5_0
- charset-normalizer=2.0.4=pyhd3eb1b0_0
- coin-or-cbc=2.10.8=h3786ebc_0
- coin-or-cgl=0.60.6=h6f57e76_2
- coin-or-clp=1.17.7=hc56784d_2
- coin-or-osi=0.108.7=h2720bb7_2
- coin-or-utils=2.11.6=h202d8b1_2
- python=3.8.13
- pip=22.2.2
- cudatoolkit=11.3
- pytorch=1.12.1
- torchvision=0.13.1
- numpy=1.23.1
- pip:
- titans
- torch==1.12.1
- pulp==2.7.0
- datasets
- colossalai
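
The pins above pull in the CUDA 11.3 build of PyTorch 1.12.1 together with pulp and the coin-or CBC packages, which the auto-parallel solver presumably relies on for its integer program. A purely illustrative sanity check after creating the environment — the version strings simply mirror the pins in this file:

import torch
import pulp  # MILP frontend listed in the pip dependencies; the coin-or packages above provide CBC

assert torch.__version__.startswith('1.12.1'), torch.__version__
assert torch.cuda.is_available(), 'expected the cudatoolkit=11.3 build of PyTorch'
print('pulp', pulp.__version__, 'available solvers:', pulp.listSolvers(onlyAvailable=True))
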
13 changes: 13 additions & 0 deletions examples/tutorial/auto_parallel/setup.py
@@ -0,0 +1,13 @@
from setuptools import find_packages, setup

setup(
name='auto_parallel',
version='0.0.1',
description='',
packages=find_packages(),
install_requires=[
'torch',
'numpy',
'tqdm',
],
)
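
setup.py only registers a stub auto_parallel package with three runtime requirements. A purely illustrative way to confirm they resolved in the active interpreter (importlib.metadata is available from Python 3.8, the version pinned in environment.yaml):

from importlib.metadata import version

for dep in ('torch', 'numpy', 'tqdm'):
    print(dep, version(dep))
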
11 changes: 11 additions & 0 deletions examples/tutorial/auto_parallel/test_ci.sh
@@ -0,0 +1,11 @@
#!/bin/bash
set -euxo pipefail

conda init bash
conda env create -f environment.yaml
conda activate auto
cd ../../..
pip uninstall -y colossalai
pip install -v .
cd ./examples/tutorial/auto_parallel
colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py -s
Contributor

-s is no longer needed if we only use the synthetic dataset. This tutorial is only meant to let users quickly start the script without any preparation.

Contributor Author

Sure