Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions colossalai/auto_parallel/offload/base_offload_module.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Optional, Set
from functools import partial
from typing import Optional, Set

import torch
import torch.nn as nn

from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.gemini.tensor_utils import free_storage
from colossalai.nn.parallel.data_parallel import _cast_float

from .region_manager import RegionManager
from .util import GlobalRuntimeInfo
Expand All @@ -20,10 +21,7 @@ class BaseOffloadModule:
is_sync (bool): synchronous mode or not.
"""

def __init__(self,
model: nn.Module,
region_manager: RegionManager,
is_sync=True):
def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True):

self.model = model
self.region_manager = region_manager
Expand Down Expand Up @@ -69,8 +67,8 @@ def _post_backward(self):
for p in self.model.parameters():
p.grad = None

GlobalRuntimeInfo.fwd_prefetch_event_map.clear()
GlobalRuntimeInfo.bwd_prefetch_event_map.clear()
GlobalRuntimeInfo().fwd_prefetch_event_map.clear()
GlobalRuntimeInfo().bwd_prefetch_event_map.clear()

def grad_handle(self, p, grad):
empty_grad = torch.empty_like(grad)
Expand All @@ -82,7 +80,7 @@ def grad_handle(self, p, grad):
self.overflow_counter += region.has_inf_or_nan
master_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.grad_offload_stream):
GlobalRuntimeInfo.d2h_stream.wait_stream(master_stream)
GlobalRuntimeInfo().d2h_stream.wait_stream(master_stream)
region.move_grad_to_cpu()
return empty_grad

Expand Down
21 changes: 12 additions & 9 deletions colossalai/auto_parallel/offload/mem_optimize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict

import torch
import torch.fx
from torch.fx import GraphModule
Expand All @@ -7,10 +8,11 @@
from colossalai.fx import ColoTracer, is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp

from .region_manager import RegionManager
from .runtime import runtime_syn_offload_apply_pass, runtime_asyn_offload_apply_pass
from .base_offload_module import BaseOffloadModule
from .util import compute_max_param_mem, compute_total_param_mem, compute_act_peak_mem, GlobalRuntimeInfo
from .region_manager import RegionManager
from .runtime import runtime_asyn_offload_apply_pass, runtime_syn_offload_apply_pass
from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem


def memory_optimize(model: torch.nn.Module,
inps: Dict[str, torch.Tensor],
Expand All @@ -29,13 +31,14 @@ def memory_optimize(model: torch.nn.Module,

region_manager = RegionManager(graph, solver_name=solver_name, memory_budget=memory_budget)
region_manager._build_regions()
GlobalRuntimeInfo.region_list = region_manager.region_list
GlobalRuntimeInfo().region_list = region_manager.region_list

act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024 ** 2
max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024 ** 2
total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024 ** 2
act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024**2
max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024**2
total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024**2
print(
f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}")
f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}"
)

if solver_name == 'syn':
gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list)
Expand All @@ -45,5 +48,5 @@ def memory_optimize(model: torch.nn.Module,
raise TypeError(f"Unknown solver name {solver_name}!")

gm.recompile()
optimized_model = BaseOffloadModule(gm, region_manager, solver_name=='syn')
optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn')
return optimized_model
59 changes: 31 additions & 28 deletions colossalai/auto_parallel/offload/runtime.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List

import torch
from torch.fx.node import Node

Expand All @@ -23,13 +24,13 @@ def forward(ctx, input_, fwd_info, bwd_info):
ctx.bwd_info = bwd_info
d2h_rid = fwd_info.get('d2h_rid', None)
if d2h_rid is not None:
free_region = GlobalRuntimeInfo.region_list[d2h_rid]
free_region = GlobalRuntimeInfo().region_list[d2h_rid]
assert isinstance(free_region, Region)
free_region.free_cuda_data()

h2d_rid = fwd_info.get('h2d_rid', None)
if h2d_rid is not None:
h2d_region = GlobalRuntimeInfo.region_list[h2d_rid]
h2d_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(h2d_region, Region)
h2d_region.move_param_to_cuda()

Expand All @@ -40,7 +41,7 @@ def backward(ctx, grad_output):

h2d_rid = ctx.bwd_info.get('h2d_rid', None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
pref_region.move_param_to_cuda()

Expand All @@ -65,23 +66,22 @@ def forward(ctx, input_, fwd_info, bwd_info):

sync_rid = fwd_info.get('sync_rid', None)
if sync_rid is not None:
prefetch_event = GlobalRuntimeInfo.fwd_prefetch_event_map.get(
sync_rid, None)
prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None)
if prefetch_event:
prefetch_event.wait()

h2d_rid = fwd_info.get('h2d_rid', None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
master_stream = torch.cuda.current_stream()
with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream):
GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream)
with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream):
GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream)
pref_region.move_param_to_cuda()

prefetch_event = torch.cuda.Event()
prefetch_event.record(GlobalRuntimeInfo.h2d_stream)
GlobalRuntimeInfo.fwd_prefetch_event_map[h2d_rid] = prefetch_event
prefetch_event.record(GlobalRuntimeInfo().h2d_stream)
GlobalRuntimeInfo().fwd_prefetch_event_map[h2d_rid] = prefetch_event

return input_

Expand All @@ -90,27 +90,26 @@ def backward(ctx, grad_output):

sync_rid = ctx.bwd_info.get('sync_rid', None)
if sync_rid is not None:
wait_region = GlobalRuntimeInfo.region_list[sync_rid]
wait_region = GlobalRuntimeInfo().region_list[sync_rid]
assert isinstance(wait_region, Region)
prefetch_event = GlobalRuntimeInfo.bwd_prefetch_event_map.get(
sync_rid, None)
prefetch_event = GlobalRuntimeInfo().bwd_prefetch_event_map.get(sync_rid, None)
if prefetch_event:
prefetch_event.wait()
else:
wait_region.move_param_to_cuda()

h2d_rid = ctx.bwd_info.get('h2d_rid', None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
master_stream = torch.cuda.current_stream()
with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream):
GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream)
with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream):
GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream)
pref_region.move_param_to_cuda()

prefetch_event = torch.cuda.Event()
prefetch_event.record(GlobalRuntimeInfo.h2d_stream)
GlobalRuntimeInfo.bwd_prefetch_event_map[h2d_rid] = prefetch_event
prefetch_event.record(GlobalRuntimeInfo().h2d_stream)
GlobalRuntimeInfo().bwd_prefetch_event_map[h2d_rid] = prefetch_event
return grad_output, None, None


Expand All @@ -129,6 +128,7 @@ def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
return ret


def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
'''
Convert Prefetch and Offload operation into runtime action.
Expand Down Expand Up @@ -189,7 +189,8 @@ def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[R

if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action,
new_node = mod_graph.create_node('call_function',
convert_fwd_upload_bwd_offload_to_action,
args=(last_inp_node, fwd_info, bwd_info))
replace_node_users(last_inp_node, new_node)

Expand All @@ -206,11 +207,11 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[

# upload parameters of the first region
last_inp_node = tuple(mod_graph.nodes)[0]
first_region_with_p = [
region for region in region_list if region.param_size][0]
first_region_with_p = [region for region in region_list if region.param_size][0]
fwd_info = {"h2d_rid": first_region_with_p.r_id}
with mod_graph.inserting_after(last_inp_node):
upload_apply_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action,
upload_apply_node = mod_graph.create_node('call_function',
convert_fwd_upload_bwd_offload_to_action,
args=(last_inp_node, fwd_info, {}))
replace_node_users(last_inp_node, upload_apply_node)
last_inp_node = upload_apply_node
Expand All @@ -225,19 +226,20 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
fwd_info['h2d_rid'] = fwd_prefetch_region.r_id

# forward offload
if r_idx > 0 and region_list[r_idx-1].need_offload:
if r_idx > 0 and region_list[r_idx - 1].need_offload:
fwd_info['d2h_rid'] = r_idx - 1

bwd_info = {}
# backward prefetch
if r_idx > 0 and region_list[r_idx-1].need_offload:
if r_idx > 0 and region_list[r_idx - 1].need_offload:
bwd_info['sync_rid'] = r_idx - 1
if r_idx > 0 and region_list[r_idx-1].bwd_prefetch_region:
bwd_info['h2d_rid'] = region_list[r_idx-1].bwd_prefetch_region.r_id
if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region:
bwd_info['h2d_rid'] = region_list[r_idx - 1].bwd_prefetch_region.r_id

if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action,
new_node = mod_graph.create_node('call_function',
convert_fwd_prefetch_bwd_offload_to_action,
args=(last_inp_node, fwd_info, bwd_info))
replace_node_users(last_inp_node, new_node)

Expand All @@ -246,7 +248,8 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
if region.bwd_prefetch_region:
bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id}
with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action,
new_node = mod_graph.create_node('call_function',
convert_fwd_prefetch_bwd_offload_to_action,
args=(last_inp_node, {}, bwd_info))
replace_node_users(last_inp_node, new_node)
# gm.graph.print_tabular()
Expand Down
33 changes: 21 additions & 12 deletions colossalai/auto_parallel/offload/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from dataclasses import dataclass
from typing import List

import torch

from colossalai.context.singleton_meta import SingletonMeta
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp

from .region import Region
Expand All @@ -12,6 +15,7 @@ class NodeInfo:
runtime_fwd_mem: float = 0
runtime_bwd_mem: float = 0


class NvDevicePower:
"""
NVIDIA GPU computing performance (TFLOPs).
Expand All @@ -30,12 +34,14 @@ class NvDevicePower:
A100_FP32 = 19.5


class GlobalRuntimeInfo(metaclass=SingletonMeta):
    """Runtime state shared between the offload module and the runtime passes.

    Callers always access it as ``GlobalRuntimeInfo()``.
    NOTE(review): this relies on ``SingletonMeta`` returning the same instance
    on every call, so state written in one place (e.g. ``region_list``) is
    visible everywhere else -- confirm against ``SingletonMeta``.

    The state is created in ``__init__`` rather than as class attributes so
    that no CUDA stream is allocated merely by importing this module.
    """

    def __init__(self):
        # Side stream on which parameter prefetches (``move_param_to_cuda``) are issued.
        self.h2d_stream = torch.cuda.Stream()
        # Side stream that gradient offload waits on before moving grads to CPU.
        self.d2h_stream = torch.cuda.Stream()
        # Region id -> ``torch.cuda.Event`` recorded after a forward prefetch.
        self.fwd_prefetch_event_map = {}
        # Region id -> ``torch.cuda.Event`` recorded after a backward prefetch.
        self.bwd_prefetch_event_map = {}
        # Regions built by the region manager; indexed by region id at runtime.
        self.region_list = []


def compute_act_peak_mem(region_list: List[Region]) -> float:
Expand Down Expand Up @@ -70,21 +76,24 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:

return act_peak_mem


def compute_max_param_mem(region_list: List[Region]) -> float:
    """Return the largest ``param_size`` among ``region_list``.

    Args:
        region_list: regions to inspect; each must expose ``param_size``.

    Returns:
        The maximum ``param_size``, or ``0`` when ``region_list`` is empty
        (mirrors ``compute_total_param_mem``, which also yields 0 for no regions).
    """
    # ``default=0`` avoids the ValueError that ``max`` raises on an empty iterable.
    return max((region.param_size for region in region_list), default=0)


def compute_total_param_mem(region_list: List[Region]) -> float:
    """Return the summed ``param_size`` of the regions that are counted.

    A region is counted only when ``r_id <= shared_rid``.
    NOTE(review): presumably this avoids counting parameters shared between
    two regions twice -- confirm against ``Region.shared_rid``.

    Args:
        region_list: regions exposing ``param_size``, ``r_id`` and ``shared_rid``.

    Returns:
        The total ``param_size`` over the counted regions (0 for an empty list).
    """
    return sum(region.param_size for region in region_list if region.r_id <= region.shared_rid)


def requires_upload_p_in_fwd(shared_reg: Region):
    """Tell whether the parameters of ``shared_reg`` must be uploaded in forward.

    True when ``r_id >= shared_rid``; otherwise the result is
    ``shared_reg.need_offload``.

    Args:
        shared_reg: region exposing ``r_id``, ``shared_rid`` and ``need_offload``.
    """
    # When the first comparison is false, ``r_id < shared_rid`` necessarily
    # holds, so that extra conjunct is redundant and is omitted here.
    return shared_reg.r_id >= shared_reg.shared_rid or shared_reg.need_offload


def requires_release_p_in_bwd(shared_reg: Region):
    """Tell whether the parameters of ``shared_reg`` must be released in backward.

    True when ``r_id >= shared_rid``; otherwise the result is
    ``shared_reg.need_offload``.

    Args:
        shared_reg: region exposing ``r_id``, ``shared_rid`` and ``need_offload``.
    """
    # When the first comparison is false, ``r_id < shared_rid`` necessarily
    # holds, so that extra conjunct is redundant and is omitted here.
    return shared_reg.r_id >= shared_reg.shared_rid or shared_reg.need_offload


def requires_offload_g_in_bwd(region: Region) -> bool:
    """Tell whether the gradients of ``region`` must be offloaded in backward.

    True only when the region has a non-zero ``param_size`` and
    ``r_id <= shared_rid``.

    Args:
        region: region exposing ``param_size``, ``r_id`` and ``shared_rid``.

    Returns:
        Always a ``bool`` (a zero ``param_size`` gives ``False`` rather than
        leaking the integer ``0`` to the caller).
    """
    return bool(region.param_size) and region.r_id <= region.shared_rid


Loading